Migrate REST switch to async (#4517)

* Migrate REST switch to async

* Update rest.py

* Address comments from paulus
This commit is contained in:
Pascal Vizeli 2016-12-13 17:55:13 +01:00 committed by Paulus Schoutsen
parent 72bd9fb5c7
commit e4b6395250
3 changed files with 204 additions and 137 deletions

View File

@ -4,23 +4,27 @@ Support for RESTful switches.
For more details about this platform, please refer to the documentation at
https://home-assistant.io/components/switch.rest/
"""
import asyncio
import logging
import requests
import aiohttp
import async_timeout
import voluptuous as vol
from homeassistant.components.switch import (SwitchDevice, PLATFORM_SCHEMA)
from homeassistant.const import (CONF_NAME, CONF_RESOURCE, CONF_TIMEOUT)
from homeassistant.helpers.aiohttp_client import async_get_clientsession
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.template import Template
CONF_BODY_OFF = 'body_off'
CONF_BODY_ON = 'body_on'
CONF_IS_ON_TEMPLATE = 'is_on_template'
DEFAULT_BODY_OFF = Template('OFF')
DEFAULT_BODY_ON = Template('ON')
DEFAULT_NAME = 'REST Switch'
DEFAULT_TIMEOUT = 10
CONF_IS_ON_TEMPLATE = 'is_on_template'
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({
vol.Required(CONF_RESOURCE): cv.url,
@ -35,13 +39,15 @@ _LOGGER = logging.getLogger(__name__)
# pylint: disable=unused-argument,
def setup_platform(hass, config, add_devices, discovery_info=None):
@asyncio.coroutine
def async_setup_platform(hass, config, async_add_devices, discovery_info=None):
"""Set up the RESTful switch."""
name = config.get(CONF_NAME)
resource = config.get(CONF_RESOURCE)
body_on = config.get(CONF_BODY_ON)
body_off = config.get(CONF_BODY_OFF)
is_on_template = config.get(CONF_IS_ON_TEMPLATE)
websession = async_get_clientsession(hass)
if is_on_template is not None:
is_on_template.hass = hass
@ -51,19 +57,24 @@ def setup_platform(hass, config, add_devices, discovery_info=None):
body_off.hass = hass
timeout = config.get(CONF_TIMEOUT)
req = None
try:
requests.get(resource, timeout=10)
except requests.exceptions.MissingSchema:
with async_timeout.timeout(timeout, loop=hass.loop):
req = yield from websession.get(resource)
except (TypeError, ValueError):
_LOGGER.error("Missing resource or schema in configuration. "
"Add http:// or https:// to your URL")
return False
except requests.exceptions.ConnectionError:
except (asyncio.TimeoutError, aiohttp.errors.ClientError):
_LOGGER.error("No route to resource/endpoint: %s", resource)
return False
finally:
if req is not None:
yield from req.release()
add_devices(
[RestSwitch(
hass, name, resource, body_on, body_off, is_on_template, timeout)])
yield from async_add_devices(
[RestSwitch(hass, name, resource, body_on, body_off,
is_on_template, timeout)])
class RestSwitch(SwitchDevice):
@ -73,7 +84,7 @@ class RestSwitch(SwitchDevice):
is_on_template, timeout):
"""Initialize the REST switch."""
self._state = None
self._hass = hass
self.hass = hass
self._name = name
self._resource = resource
self._body_on = body_on
@ -91,46 +102,85 @@ class RestSwitch(SwitchDevice):
"""Return true if device is on."""
return self._state
def turn_on(self, **kwargs):
@asyncio.coroutine
def async_turn_on(self, **kwargs):
"""Turn the device on."""
body_on_t = self._body_on.render()
request = requests.post(
self._resource, data=body_on_t, timeout=self._timeout)
if request.status_code == 200:
body_on_t = self._body_on.async_render()
websession = async_get_clientsession(self.hass)
request = None
try:
with async_timeout.timeout(self._timeout, loop=self.hass.loop):
request = yield from websession.post(
self._resource, data=bytes(body_on_t, 'utf-8'))
except (asyncio.TimeoutError, aiohttp.errors.ClientError):
_LOGGER.error("Error while turn on %s", self._resource)
return
finally:
if request is not None:
yield from request.release()
if request.status == 200:
self._state = True
else:
_LOGGER.error("Can't turn on %s. Is resource/endpoint offline?",
self._resource)
def turn_off(self, **kwargs):
@asyncio.coroutine
def async_turn_off(self, **kwargs):
"""Turn the device off."""
body_off_t = self._body_off.render()
request = requests.post(
self._resource, data=body_off_t, timeout=self._timeout)
if request.status_code == 200:
body_off_t = self._body_off.async_render()
websession = async_get_clientsession(self.hass)
request = None
try:
with async_timeout.timeout(self._timeout, loop=self.hass.loop):
request = yield from websession.post(
self._resource, data=bytes(body_off_t, 'utf-8'))
except (asyncio.TimeoutError, aiohttp.errors.ClientError):
_LOGGER.error("Error while turn off %s", self._resource)
return
finally:
if request is not None:
yield from request.release()
if request.status == 200:
self._state = False
else:
_LOGGER.error("Can't turn off %s. Is resource/endpoint offline?",
self._resource)
def update(self):
@asyncio.coroutine
def async_update(self):
"""Get the latest data from REST API and update the state."""
request = requests.get(self._resource, timeout=self._timeout)
websession = async_get_clientsession(self.hass)
request = None
try:
with async_timeout.timeout(self._timeout, loop=self.hass.loop):
request = yield from websession.get(self._resource)
text = yield from request.text()
except (asyncio.TimeoutError, aiohttp.errors.ClientError):
_LOGGER.exception("Error while fetch data.")
return
finally:
if request is not None:
yield from request.release()
if self._is_on_template is not None:
response = self._is_on_template.render_with_possible_json_value(
request.text, 'None')
response = response.lower()
if response == 'true':
text = self._is_on_template.async_render_with_possible_json_value(
text, 'None')
text = text.lower()
if text == 'true':
self._state = True
elif response == 'false':
elif text == 'false':
self._state = False
else:
self._state = None
else:
if request.text == self._body_on.template:
if text == self._body_on.template:
self._state = True
elif request.text == self._body_off.template:
elif text == self._body_off.template:
self._state = False
else:
self._state = None

View File

@ -1,76 +1,83 @@
"""The tests for the REST switch platform."""
import unittest
from unittest.mock import patch
import asyncio
import pytest
import requests
from requests.exceptions import Timeout
import requests_mock
import aiohttp
import homeassistant.components.switch.rest as rest
from homeassistant.bootstrap import setup_component
from homeassistant.util.async import run_coroutine_threadsafe
from homeassistant.helpers.template import Template
from tests.common import get_test_home_assistant, assert_setup_component
class TestRestSwitchSetup(unittest.TestCase):
class TestRestSwitchSetup:
"""Tests for setting up the REST switch platform."""
def setUp(self):
def setup_method(self):
"""Setup things to be run when tests are started."""
self.hass = get_test_home_assistant()
def tearDown(self):
def teardown_method(self):
"""Stop everything that was started."""
self.hass.stop()
def test_setup_missing_config(self):
"""Test setup with configuration missing required entries."""
self.assertFalse(rest.setup_platform(self.hass, {
'platform': 'rest'
}, None))
assert not run_coroutine_threadsafe(
rest.async_setup_platform(self.hass, {
'platform': 'rest'
}, None),
self.hass.loop
).result()
def test_setup_missing_schema(self):
"""Test setup with resource missing schema."""
self.assertFalse(rest.setup_platform(self.hass, {
'platform': 'rest',
'resource': 'localhost'
}, None))
assert not run_coroutine_threadsafe(
rest.async_setup_platform(self.hass, {
'platform': 'rest',
'resource': 'localhost'
}, None),
self.hass.loop
).result()
@patch('requests.get', side_effect=requests.exceptions.ConnectionError())
def test_setup_failed_connect(self, mock_req):
def test_setup_failed_connect(self, aioclient_mock):
"""Test setup when connection error occurs."""
self.assertFalse(rest.setup_platform(self.hass, {
'platform': 'rest',
'resource': 'http://localhost',
}, None))
@patch('requests.get', side_effect=Timeout())
def test_setup_timeout(self, mock_req):
"""Test setup when connection timeout occurs."""
with self.assertRaises(Timeout):
rest.setup_platform(self.hass, {
aioclient_mock.get('http://localhost', exc=aiohttp.errors.ClientError)
assert not run_coroutine_threadsafe(
rest.async_setup_platform(self.hass, {
'platform': 'rest',
'resource': 'http://localhost',
}, None)
}, None),
self.hass.loop
).result()
@requests_mock.Mocker()
def test_setup_minimum(self, mock_req):
"""Test setup with minimum configuration."""
mock_req.get('http://localhost', status_code=200)
self.assertTrue(setup_component(self.hass, 'switch', {
'switch': {
def test_setup_timeout(self, aioclient_mock):
"""Test setup when connection timeout occurs."""
aioclient_mock.get('http://localhost', exc=asyncio.TimeoutError())
assert not run_coroutine_threadsafe(
rest.async_setup_platform(self.hass, {
'platform': 'rest',
'resource': 'http://localhost'
}
}))
self.assertEqual(1, mock_req.call_count)
assert_setup_component(1, 'switch')
'resource': 'http://localhost',
}, None),
self.hass.loop
).result()
@requests_mock.Mocker()
def test_setup(self, mock_req):
def test_setup_minimum(self, aioclient_mock):
"""Test setup with minimum configuration."""
aioclient_mock.get('http://localhost', status=200)
with assert_setup_component(1, 'switch'):
assert setup_component(self.hass, 'switch', {
'switch': {
'platform': 'rest',
'resource': 'http://localhost'
}
})
assert aioclient_mock.call_count == 1
def test_setup(self, aioclient_mock):
"""Test setup with valid configuration."""
mock_req.get('localhost', status_code=200)
self.assertTrue(setup_component(self.hass, 'switch', {
aioclient_mock.get('http://localhost', status=200)
assert setup_component(self.hass, 'switch', {
'switch': {
'platform': 'rest',
'name': 'foo',
@ -78,111 +85,120 @@ class TestRestSwitchSetup(unittest.TestCase):
'body_on': 'custom on text',
'body_off': 'custom off text',
}
}))
self.assertEqual(1, mock_req.call_count)
})
assert aioclient_mock.call_count == 1
assert_setup_component(1, 'switch')
@pytest.mark.skip
class TestRestSwitch(unittest.TestCase):
class TestRestSwitch:
"""Tests for REST switch platform."""
def setUp(self):
def setup_method(self):
"""Setup things to be run when tests are started."""
self.hass = get_test_home_assistant()
self.name = 'foo'
self.resource = 'http://localhost/'
self.body_on = 'on'
self.body_off = 'off'
self.body_on = Template('on', self.hass)
self.body_off = Template('off', self.hass)
self.switch = rest.RestSwitch(self.hass, self.name, self.resource,
self.body_on, self.body_off)
self.body_on, self.body_off, None, 10)
def tearDown(self):
def teardown_method(self):
"""Stop everything that was started."""
self.hass.stop()
def test_name(self):
"""Test the name."""
self.assertEqual(self.name, self.switch.name)
assert self.name == self.switch.name
def test_is_on_before_update(self):
"""Test is_on in initial state."""
self.assertEqual(None, self.switch.is_on)
assert self.switch.is_on is None
@requests_mock.Mocker()
def test_turn_on_success(self, mock_req):
def test_turn_on_success(self, aioclient_mock):
"""Test turn_on."""
mock_req.post(self.resource, status_code=200)
self.switch.turn_on()
aioclient_mock.post(self.resource, status=200)
run_coroutine_threadsafe(
self.switch.async_turn_on(), self.hass.loop).result()
self.assertEqual(self.body_on, mock_req.last_request.text)
self.assertEqual(True, self.switch.is_on)
assert self.body_on.template == \
aioclient_mock.mock_calls[-1][2].decode()
assert self.switch.is_on
@requests_mock.Mocker()
def test_turn_on_status_not_ok(self, mock_req):
def test_turn_on_status_not_ok(self, aioclient_mock):
"""Test turn_on when error status returned."""
mock_req.post(self.resource, status_code=500)
self.switch.turn_on()
aioclient_mock.post(self.resource, status=500)
run_coroutine_threadsafe(
self.switch.async_turn_on(), self.hass.loop).result()
self.assertEqual(self.body_on, mock_req.last_request.text)
self.assertEqual(None, self.switch.is_on)
assert self.body_on.template == \
aioclient_mock.mock_calls[-1][2].decode()
assert self.switch.is_on is None
@patch('requests.post', side_effect=Timeout())
def test_turn_on_timeout(self, mock_req):
def test_turn_on_timeout(self, aioclient_mock):
"""Test turn_on when timeout occurs."""
with self.assertRaises(Timeout):
self.switch.turn_on()
aioclient_mock.post(self.resource, status=500)
run_coroutine_threadsafe(
self.switch.async_turn_on(), self.hass.loop).result()
@requests_mock.Mocker()
def test_turn_off_success(self, mock_req):
assert self.switch.is_on is None
def test_turn_off_success(self, aioclient_mock):
"""Test turn_off."""
mock_req.post(self.resource, status_code=200)
self.switch.turn_off()
aioclient_mock.post(self.resource, status=200)
run_coroutine_threadsafe(
self.switch.async_turn_off(), self.hass.loop).result()
self.assertEqual(self.body_off, mock_req.last_request.text)
self.assertEqual(False, self.switch.is_on)
assert self.body_off.template == \
aioclient_mock.mock_calls[-1][2].decode()
assert not self.switch.is_on
@requests_mock.Mocker()
def test_turn_off_status_not_ok(self, mock_req):
def test_turn_off_status_not_ok(self, aioclient_mock):
"""Test turn_off when error status returned."""
mock_req.post(self.resource, status_code=500)
self.switch.turn_off()
aioclient_mock.post(self.resource, status=500)
run_coroutine_threadsafe(
self.switch.async_turn_off(), self.hass.loop).result()
self.assertEqual(self.body_off, mock_req.last_request.text)
self.assertEqual(None, self.switch.is_on)
assert self.body_off.template == \
aioclient_mock.mock_calls[-1][2].decode()
assert self.switch.is_on is None
@patch('requests.post', side_effect=Timeout())
def test_turn_off_timeout(self, mock_req):
def test_turn_off_timeout(self, aioclient_mock):
"""Test turn_off when timeout occurs."""
with self.assertRaises(Timeout):
self.switch.turn_on()
aioclient_mock.post(self.resource, exc=asyncio.TimeoutError())
run_coroutine_threadsafe(
self.switch.async_turn_on(), self.hass.loop).result()
@requests_mock.Mocker()
def test_update_when_on(self, mock_req):
assert self.switch.is_on is None
def test_update_when_on(self, aioclient_mock):
"""Test update when switch is on."""
mock_req.get(self.resource, text=self.body_on)
self.switch.update()
aioclient_mock.get(self.resource, text=self.body_on.template)
run_coroutine_threadsafe(
self.switch.async_update(), self.hass.loop).result()
self.assertEqual(True, self.switch.is_on)
assert self.switch.is_on
@requests_mock.Mocker()
def test_update_when_off(self, mock_req):
def test_update_when_off(self, aioclient_mock):
"""Test update when switch is off."""
mock_req.get(self.resource, text=self.body_off)
self.switch.update()
aioclient_mock.get(self.resource, text=self.body_off.template)
run_coroutine_threadsafe(
self.switch.async_update(), self.hass.loop).result()
self.assertEqual(False, self.switch.is_on)
assert not self.switch.is_on
@requests_mock.Mocker()
def test_update_when_unknown(self, mock_req):
def test_update_when_unknown(self, aioclient_mock):
"""Test update when unknown status returned."""
mock_req.get(self.resource, text='unknown status')
self.switch.update()
aioclient_mock.get(self.resource, text='unknown status')
run_coroutine_threadsafe(
self.switch.async_update(), self.hass.loop).result()
self.assertEqual(None, self.switch.is_on)
assert self.switch.is_on is None
@patch('requests.get', side_effect=Timeout())
def test_update_timeout(self, mock_req):
def test_update_timeout(self, aioclient_mock):
"""Test update when timeout occurs."""
with self.assertRaises(Timeout):
self.switch.update()
aioclient_mock.get(self.resource, exc=asyncio.TimeoutError())
run_coroutine_threadsafe(
self.switch.async_update(), self.hass.loop).result()
assert self.switch.is_on is None

View File

@ -20,6 +20,7 @@ class AiohttpClientMocker:
auth=None,
status=200,
text=None,
data=None,
content=None,
json=None,
params=None,
@ -66,12 +67,12 @@ class AiohttpClientMocker:
return len(self.mock_calls)
@asyncio.coroutine
def match_request(self, method, url, *, auth=None, params=None,
def match_request(self, method, url, *, data=None, auth=None, params=None,
headers=None): # pylint: disable=unused-variable
"""Match a request against pre-registered requests."""
for response in self._mocks:
if response.match_request(method, url, params):
self.mock_calls.append((method, url))
self.mock_calls.append((method, url, data))
if self.exc:
raise self.exc