Migrate REST switch to async (#4517)

* Migrate REST switch to async

* Update rest.py

* Address comments from paulus
This commit is contained in:
Pascal Vizeli 2016-12-13 17:55:13 +01:00 committed by Paulus Schoutsen
parent 72bd9fb5c7
commit e4b6395250
3 changed files with 204 additions and 137 deletions

View File

@ -4,23 +4,27 @@ Support for RESTful switches.
For more details about this platform, please refer to the documentation at For more details about this platform, please refer to the documentation at
https://home-assistant.io/components/switch.rest/ https://home-assistant.io/components/switch.rest/
""" """
import asyncio
import logging import logging
import requests import aiohttp
import async_timeout
import voluptuous as vol import voluptuous as vol
from homeassistant.components.switch import (SwitchDevice, PLATFORM_SCHEMA) from homeassistant.components.switch import (SwitchDevice, PLATFORM_SCHEMA)
from homeassistant.const import (CONF_NAME, CONF_RESOURCE, CONF_TIMEOUT) from homeassistant.const import (CONF_NAME, CONF_RESOURCE, CONF_TIMEOUT)
from homeassistant.helpers.aiohttp_client import async_get_clientsession
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.template import Template from homeassistant.helpers.template import Template
CONF_BODY_OFF = 'body_off' CONF_BODY_OFF = 'body_off'
CONF_BODY_ON = 'body_on' CONF_BODY_ON = 'body_on'
CONF_IS_ON_TEMPLATE = 'is_on_template'
DEFAULT_BODY_OFF = Template('OFF') DEFAULT_BODY_OFF = Template('OFF')
DEFAULT_BODY_ON = Template('ON') DEFAULT_BODY_ON = Template('ON')
DEFAULT_NAME = 'REST Switch' DEFAULT_NAME = 'REST Switch'
DEFAULT_TIMEOUT = 10 DEFAULT_TIMEOUT = 10
CONF_IS_ON_TEMPLATE = 'is_on_template'
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({
vol.Required(CONF_RESOURCE): cv.url, vol.Required(CONF_RESOURCE): cv.url,
@ -35,13 +39,15 @@ _LOGGER = logging.getLogger(__name__)
# pylint: disable=unused-argument, # pylint: disable=unused-argument,
def setup_platform(hass, config, add_devices, discovery_info=None): @asyncio.coroutine
def async_setup_platform(hass, config, async_add_devices, discovery_info=None):
"""Set up the RESTful switch.""" """Set up the RESTful switch."""
name = config.get(CONF_NAME) name = config.get(CONF_NAME)
resource = config.get(CONF_RESOURCE) resource = config.get(CONF_RESOURCE)
body_on = config.get(CONF_BODY_ON) body_on = config.get(CONF_BODY_ON)
body_off = config.get(CONF_BODY_OFF) body_off = config.get(CONF_BODY_OFF)
is_on_template = config.get(CONF_IS_ON_TEMPLATE) is_on_template = config.get(CONF_IS_ON_TEMPLATE)
websession = async_get_clientsession(hass)
if is_on_template is not None: if is_on_template is not None:
is_on_template.hass = hass is_on_template.hass = hass
@ -51,19 +57,24 @@ def setup_platform(hass, config, add_devices, discovery_info=None):
body_off.hass = hass body_off.hass = hass
timeout = config.get(CONF_TIMEOUT) timeout = config.get(CONF_TIMEOUT)
req = None
try: try:
requests.get(resource, timeout=10) with async_timeout.timeout(timeout, loop=hass.loop):
except requests.exceptions.MissingSchema: req = yield from websession.get(resource)
except (TypeError, ValueError):
_LOGGER.error("Missing resource or schema in configuration. " _LOGGER.error("Missing resource or schema in configuration. "
"Add http:// or https:// to your URL") "Add http:// or https:// to your URL")
return False return False
except requests.exceptions.ConnectionError: except (asyncio.TimeoutError, aiohttp.errors.ClientError):
_LOGGER.error("No route to resource/endpoint: %s", resource) _LOGGER.error("No route to resource/endpoint: %s", resource)
return False return False
finally:
if req is not None:
yield from req.release()
add_devices( yield from async_add_devices(
[RestSwitch( [RestSwitch(hass, name, resource, body_on, body_off,
hass, name, resource, body_on, body_off, is_on_template, timeout)]) is_on_template, timeout)])
class RestSwitch(SwitchDevice): class RestSwitch(SwitchDevice):
@ -73,7 +84,7 @@ class RestSwitch(SwitchDevice):
is_on_template, timeout): is_on_template, timeout):
"""Initialize the REST switch.""" """Initialize the REST switch."""
self._state = None self._state = None
self._hass = hass self.hass = hass
self._name = name self._name = name
self._resource = resource self._resource = resource
self._body_on = body_on self._body_on = body_on
@ -91,46 +102,85 @@ class RestSwitch(SwitchDevice):
"""Return true if device is on.""" """Return true if device is on."""
return self._state return self._state
def turn_on(self, **kwargs): @asyncio.coroutine
def async_turn_on(self, **kwargs):
"""Turn the device on.""" """Turn the device on."""
body_on_t = self._body_on.render() body_on_t = self._body_on.async_render()
request = requests.post( websession = async_get_clientsession(self.hass)
self._resource, data=body_on_t, timeout=self._timeout)
if request.status_code == 200: request = None
try:
with async_timeout.timeout(self._timeout, loop=self.hass.loop):
request = yield from websession.post(
self._resource, data=bytes(body_on_t, 'utf-8'))
except (asyncio.TimeoutError, aiohttp.errors.ClientError):
_LOGGER.error("Error while turn on %s", self._resource)
return
finally:
if request is not None:
yield from request.release()
if request.status == 200:
self._state = True self._state = True
else: else:
_LOGGER.error("Can't turn on %s. Is resource/endpoint offline?", _LOGGER.error("Can't turn on %s. Is resource/endpoint offline?",
self._resource) self._resource)
def turn_off(self, **kwargs): @asyncio.coroutine
def async_turn_off(self, **kwargs):
"""Turn the device off.""" """Turn the device off."""
body_off_t = self._body_off.render() body_off_t = self._body_off.async_render()
request = requests.post( websession = async_get_clientsession(self.hass)
self._resource, data=body_off_t, timeout=self._timeout)
if request.status_code == 200: request = None
try:
with async_timeout.timeout(self._timeout, loop=self.hass.loop):
request = yield from websession.post(
self._resource, data=bytes(body_off_t, 'utf-8'))
except (asyncio.TimeoutError, aiohttp.errors.ClientError):
_LOGGER.error("Error while turn off %s", self._resource)
return
finally:
if request is not None:
yield from request.release()
if request.status == 200:
self._state = False self._state = False
else: else:
_LOGGER.error("Can't turn off %s. Is resource/endpoint offline?", _LOGGER.error("Can't turn off %s. Is resource/endpoint offline?",
self._resource) self._resource)
def update(self): @asyncio.coroutine
def async_update(self):
"""Get the latest data from REST API and update the state.""" """Get the latest data from REST API and update the state."""
request = requests.get(self._resource, timeout=self._timeout) websession = async_get_clientsession(self.hass)
request = None
try:
with async_timeout.timeout(self._timeout, loop=self.hass.loop):
request = yield from websession.get(self._resource)
text = yield from request.text()
except (asyncio.TimeoutError, aiohttp.errors.ClientError):
_LOGGER.exception("Error while fetch data.")
return
finally:
if request is not None:
yield from request.release()
if self._is_on_template is not None: if self._is_on_template is not None:
response = self._is_on_template.render_with_possible_json_value( text = self._is_on_template.async_render_with_possible_json_value(
request.text, 'None') text, 'None')
response = response.lower() text = text.lower()
if response == 'true': if text == 'true':
self._state = True self._state = True
elif response == 'false': elif text == 'false':
self._state = False self._state = False
else: else:
self._state = None self._state = None
else: else:
if request.text == self._body_on.template: if text == self._body_on.template:
self._state = True self._state = True
elif request.text == self._body_off.template: elif text == self._body_off.template:
self._state = False self._state = False
else: else:
self._state = None self._state = None

View File

@ -1,76 +1,83 @@
"""The tests for the REST switch platform.""" """The tests for the REST switch platform."""
import unittest import asyncio
from unittest.mock import patch
import pytest import aiohttp
import requests
from requests.exceptions import Timeout
import requests_mock
import homeassistant.components.switch.rest as rest import homeassistant.components.switch.rest as rest
from homeassistant.bootstrap import setup_component from homeassistant.bootstrap import setup_component
from homeassistant.util.async import run_coroutine_threadsafe
from homeassistant.helpers.template import Template
from tests.common import get_test_home_assistant, assert_setup_component from tests.common import get_test_home_assistant, assert_setup_component
class TestRestSwitchSetup(unittest.TestCase): class TestRestSwitchSetup:
"""Tests for setting up the REST switch platform.""" """Tests for setting up the REST switch platform."""
def setUp(self): def setup_method(self):
"""Setup things to be run when tests are started.""" """Setup things to be run when tests are started."""
self.hass = get_test_home_assistant() self.hass = get_test_home_assistant()
def tearDown(self): def teardown_method(self):
"""Stop everything that was started.""" """Stop everything that was started."""
self.hass.stop() self.hass.stop()
def test_setup_missing_config(self): def test_setup_missing_config(self):
"""Test setup with configuration missing required entries.""" """Test setup with configuration missing required entries."""
self.assertFalse(rest.setup_platform(self.hass, { assert not run_coroutine_threadsafe(
rest.async_setup_platform(self.hass, {
'platform': 'rest' 'platform': 'rest'
}, None)) }, None),
self.hass.loop
).result()
def test_setup_missing_schema(self): def test_setup_missing_schema(self):
"""Test setup with resource missing schema.""" """Test setup with resource missing schema."""
self.assertFalse(rest.setup_platform(self.hass, { assert not run_coroutine_threadsafe(
rest.async_setup_platform(self.hass, {
'platform': 'rest', 'platform': 'rest',
'resource': 'localhost' 'resource': 'localhost'
}, None)) }, None),
self.hass.loop
).result()
@patch('requests.get', side_effect=requests.exceptions.ConnectionError()) def test_setup_failed_connect(self, aioclient_mock):
def test_setup_failed_connect(self, mock_req):
"""Test setup when connection error occurs.""" """Test setup when connection error occurs."""
self.assertFalse(rest.setup_platform(self.hass, { aioclient_mock.get('http://localhost', exc=aiohttp.errors.ClientError)
assert not run_coroutine_threadsafe(
rest.async_setup_platform(self.hass, {
'platform': 'rest', 'platform': 'rest',
'resource': 'http://localhost', 'resource': 'http://localhost',
}, None)) }, None),
self.hass.loop
).result()
@patch('requests.get', side_effect=Timeout()) def test_setup_timeout(self, aioclient_mock):
def test_setup_timeout(self, mock_req):
"""Test setup when connection timeout occurs.""" """Test setup when connection timeout occurs."""
with self.assertRaises(Timeout): aioclient_mock.get('http://localhost', exc=asyncio.TimeoutError())
rest.setup_platform(self.hass, { assert not run_coroutine_threadsafe(
rest.async_setup_platform(self.hass, {
'platform': 'rest', 'platform': 'rest',
'resource': 'http://localhost', 'resource': 'http://localhost',
}, None) }, None),
self.hass.loop
).result()
@requests_mock.Mocker() def test_setup_minimum(self, aioclient_mock):
def test_setup_minimum(self, mock_req):
"""Test setup with minimum configuration.""" """Test setup with minimum configuration."""
mock_req.get('http://localhost', status_code=200) aioclient_mock.get('http://localhost', status=200)
self.assertTrue(setup_component(self.hass, 'switch', { with assert_setup_component(1, 'switch'):
assert setup_component(self.hass, 'switch', {
'switch': { 'switch': {
'platform': 'rest', 'platform': 'rest',
'resource': 'http://localhost' 'resource': 'http://localhost'
} }
})) })
self.assertEqual(1, mock_req.call_count) assert aioclient_mock.call_count == 1
assert_setup_component(1, 'switch')
@requests_mock.Mocker() def test_setup(self, aioclient_mock):
def test_setup(self, mock_req):
"""Test setup with valid configuration.""" """Test setup with valid configuration."""
mock_req.get('localhost', status_code=200) aioclient_mock.get('http://localhost', status=200)
self.assertTrue(setup_component(self.hass, 'switch', { assert setup_component(self.hass, 'switch', {
'switch': { 'switch': {
'platform': 'rest', 'platform': 'rest',
'name': 'foo', 'name': 'foo',
@ -78,111 +85,120 @@ class TestRestSwitchSetup(unittest.TestCase):
'body_on': 'custom on text', 'body_on': 'custom on text',
'body_off': 'custom off text', 'body_off': 'custom off text',
} }
})) })
self.assertEqual(1, mock_req.call_count) assert aioclient_mock.call_count == 1
assert_setup_component(1, 'switch') assert_setup_component(1, 'switch')
@pytest.mark.skip class TestRestSwitch:
class TestRestSwitch(unittest.TestCase):
"""Tests for REST switch platform.""" """Tests for REST switch platform."""
def setUp(self): def setup_method(self):
"""Setup things to be run when tests are started.""" """Setup things to be run when tests are started."""
self.hass = get_test_home_assistant() self.hass = get_test_home_assistant()
self.name = 'foo' self.name = 'foo'
self.resource = 'http://localhost/' self.resource = 'http://localhost/'
self.body_on = 'on' self.body_on = Template('on', self.hass)
self.body_off = 'off' self.body_off = Template('off', self.hass)
self.switch = rest.RestSwitch(self.hass, self.name, self.resource, self.switch = rest.RestSwitch(self.hass, self.name, self.resource,
self.body_on, self.body_off) self.body_on, self.body_off, None, 10)
def tearDown(self): def teardown_method(self):
"""Stop everything that was started.""" """Stop everything that was started."""
self.hass.stop() self.hass.stop()
def test_name(self): def test_name(self):
"""Test the name.""" """Test the name."""
self.assertEqual(self.name, self.switch.name) assert self.name == self.switch.name
def test_is_on_before_update(self): def test_is_on_before_update(self):
"""Test is_on in initial state.""" """Test is_on in initial state."""
self.assertEqual(None, self.switch.is_on) assert self.switch.is_on is None
@requests_mock.Mocker() def test_turn_on_success(self, aioclient_mock):
def test_turn_on_success(self, mock_req):
"""Test turn_on.""" """Test turn_on."""
mock_req.post(self.resource, status_code=200) aioclient_mock.post(self.resource, status=200)
self.switch.turn_on() run_coroutine_threadsafe(
self.switch.async_turn_on(), self.hass.loop).result()
self.assertEqual(self.body_on, mock_req.last_request.text) assert self.body_on.template == \
self.assertEqual(True, self.switch.is_on) aioclient_mock.mock_calls[-1][2].decode()
assert self.switch.is_on
@requests_mock.Mocker() def test_turn_on_status_not_ok(self, aioclient_mock):
def test_turn_on_status_not_ok(self, mock_req):
"""Test turn_on when error status returned.""" """Test turn_on when error status returned."""
mock_req.post(self.resource, status_code=500) aioclient_mock.post(self.resource, status=500)
self.switch.turn_on() run_coroutine_threadsafe(
self.switch.async_turn_on(), self.hass.loop).result()
self.assertEqual(self.body_on, mock_req.last_request.text) assert self.body_on.template == \
self.assertEqual(None, self.switch.is_on) aioclient_mock.mock_calls[-1][2].decode()
assert self.switch.is_on is None
@patch('requests.post', side_effect=Timeout()) def test_turn_on_timeout(self, aioclient_mock):
def test_turn_on_timeout(self, mock_req):
"""Test turn_on when timeout occurs.""" """Test turn_on when timeout occurs."""
with self.assertRaises(Timeout): aioclient_mock.post(self.resource, status=500)
self.switch.turn_on() run_coroutine_threadsafe(
self.switch.async_turn_on(), self.hass.loop).result()
@requests_mock.Mocker() assert self.switch.is_on is None
def test_turn_off_success(self, mock_req):
def test_turn_off_success(self, aioclient_mock):
"""Test turn_off.""" """Test turn_off."""
mock_req.post(self.resource, status_code=200) aioclient_mock.post(self.resource, status=200)
self.switch.turn_off() run_coroutine_threadsafe(
self.switch.async_turn_off(), self.hass.loop).result()
self.assertEqual(self.body_off, mock_req.last_request.text) assert self.body_off.template == \
self.assertEqual(False, self.switch.is_on) aioclient_mock.mock_calls[-1][2].decode()
assert not self.switch.is_on
@requests_mock.Mocker() def test_turn_off_status_not_ok(self, aioclient_mock):
def test_turn_off_status_not_ok(self, mock_req):
"""Test turn_off when error status returned.""" """Test turn_off when error status returned."""
mock_req.post(self.resource, status_code=500) aioclient_mock.post(self.resource, status=500)
self.switch.turn_off() run_coroutine_threadsafe(
self.switch.async_turn_off(), self.hass.loop).result()
self.assertEqual(self.body_off, mock_req.last_request.text) assert self.body_off.template == \
self.assertEqual(None, self.switch.is_on) aioclient_mock.mock_calls[-1][2].decode()
assert self.switch.is_on is None
@patch('requests.post', side_effect=Timeout()) def test_turn_off_timeout(self, aioclient_mock):
def test_turn_off_timeout(self, mock_req):
"""Test turn_off when timeout occurs.""" """Test turn_off when timeout occurs."""
with self.assertRaises(Timeout): aioclient_mock.post(self.resource, exc=asyncio.TimeoutError())
self.switch.turn_on() run_coroutine_threadsafe(
self.switch.async_turn_on(), self.hass.loop).result()
@requests_mock.Mocker() assert self.switch.is_on is None
def test_update_when_on(self, mock_req):
def test_update_when_on(self, aioclient_mock):
"""Test update when switch is on.""" """Test update when switch is on."""
mock_req.get(self.resource, text=self.body_on) aioclient_mock.get(self.resource, text=self.body_on.template)
self.switch.update() run_coroutine_threadsafe(
self.switch.async_update(), self.hass.loop).result()
self.assertEqual(True, self.switch.is_on) assert self.switch.is_on
@requests_mock.Mocker() def test_update_when_off(self, aioclient_mock):
def test_update_when_off(self, mock_req):
"""Test update when switch is off.""" """Test update when switch is off."""
mock_req.get(self.resource, text=self.body_off) aioclient_mock.get(self.resource, text=self.body_off.template)
self.switch.update() run_coroutine_threadsafe(
self.switch.async_update(), self.hass.loop).result()
self.assertEqual(False, self.switch.is_on) assert not self.switch.is_on
@requests_mock.Mocker() def test_update_when_unknown(self, aioclient_mock):
def test_update_when_unknown(self, mock_req):
"""Test update when unknown status returned.""" """Test update when unknown status returned."""
mock_req.get(self.resource, text='unknown status') aioclient_mock.get(self.resource, text='unknown status')
self.switch.update() run_coroutine_threadsafe(
self.switch.async_update(), self.hass.loop).result()
self.assertEqual(None, self.switch.is_on) assert self.switch.is_on is None
@patch('requests.get', side_effect=Timeout()) def test_update_timeout(self, aioclient_mock):
def test_update_timeout(self, mock_req):
"""Test update when timeout occurs.""" """Test update when timeout occurs."""
with self.assertRaises(Timeout): aioclient_mock.get(self.resource, exc=asyncio.TimeoutError())
self.switch.update() run_coroutine_threadsafe(
self.switch.async_update(), self.hass.loop).result()
assert self.switch.is_on is None

View File

@ -20,6 +20,7 @@ class AiohttpClientMocker:
auth=None, auth=None,
status=200, status=200,
text=None, text=None,
data=None,
content=None, content=None,
json=None, json=None,
params=None, params=None,
@ -66,12 +67,12 @@ class AiohttpClientMocker:
return len(self.mock_calls) return len(self.mock_calls)
@asyncio.coroutine @asyncio.coroutine
def match_request(self, method, url, *, auth=None, params=None, def match_request(self, method, url, *, data=None, auth=None, params=None,
headers=None): # pylint: disable=unused-variable headers=None): # pylint: disable=unused-variable
"""Match a request against pre-registered requests.""" """Match a request against pre-registered requests."""
for response in self._mocks: for response in self._mocks:
if response.match_request(method, url, params): if response.match_request(method, url, params):
self.mock_calls.append((method, url)) self.mock_calls.append((method, url, data))
if self.exc: if self.exc:
raise self.exc raise self.exc