mirror of
https://github.com/home-assistant/core.git
synced 2025-07-14 16:57:10 +00:00
Fix tests (#7659)
* Remove global hass * Http.auth test no longer spin up server * Remove server usage from http.ban test * Remove setupModule from test device_sun_light_trigger * Update common.py
This commit is contained in:
parent
5aa72562a7
commit
d369d70ca5
@ -150,14 +150,14 @@ def async_setup(hass: HomeAssistantType, config: ConfigType):
|
|||||||
scanner = yield from platform.async_get_scanner(
|
scanner = yield from platform.async_get_scanner(
|
||||||
hass, {DOMAIN: p_config})
|
hass, {DOMAIN: p_config})
|
||||||
elif hasattr(platform, 'get_scanner'):
|
elif hasattr(platform, 'get_scanner'):
|
||||||
scanner = yield from hass.loop.run_in_executor(
|
scanner = yield from hass.async_add_job(
|
||||||
None, platform.get_scanner, hass, {DOMAIN: p_config})
|
platform.get_scanner, hass, {DOMAIN: p_config})
|
||||||
elif hasattr(platform, 'async_setup_scanner'):
|
elif hasattr(platform, 'async_setup_scanner'):
|
||||||
setup = yield from platform.async_setup_scanner(
|
setup = yield from platform.async_setup_scanner(
|
||||||
hass, p_config, tracker.async_see, disc_info)
|
hass, p_config, tracker.async_see, disc_info)
|
||||||
elif hasattr(platform, 'setup_scanner'):
|
elif hasattr(platform, 'setup_scanner'):
|
||||||
setup = yield from hass.loop.run_in_executor(
|
setup = yield from hass.async_add_job(
|
||||||
None, platform.setup_scanner, hass, p_config, tracker.see,
|
platform.setup_scanner, hass, p_config, tracker.see,
|
||||||
disc_info)
|
disc_info)
|
||||||
else:
|
else:
|
||||||
raise HomeAssistantError("Invalid device_tracker platform.")
|
raise HomeAssistantError("Invalid device_tracker platform.")
|
||||||
@ -209,8 +209,8 @@ def async_setup(hass: HomeAssistantType, config: ConfigType):
|
|||||||
ATTR_GPS, ATTR_GPS_ACCURACY, ATTR_BATTERY, ATTR_ATTRIBUTES)}
|
ATTR_GPS, ATTR_GPS_ACCURACY, ATTR_BATTERY, ATTR_ATTRIBUTES)}
|
||||||
yield from tracker.async_see(**args)
|
yield from tracker.async_see(**args)
|
||||||
|
|
||||||
descriptions = yield from hass.loop.run_in_executor(
|
descriptions = yield from hass.async_add_job(
|
||||||
None, load_yaml_config_file,
|
load_yaml_config_file,
|
||||||
os.path.join(os.path.dirname(__file__), 'services.yaml')
|
os.path.join(os.path.dirname(__file__), 'services.yaml')
|
||||||
)
|
)
|
||||||
hass.services.async_register(
|
hass.services.async_register(
|
||||||
@ -322,8 +322,8 @@ class DeviceTracker(object):
|
|||||||
This method is a coroutine.
|
This method is a coroutine.
|
||||||
"""
|
"""
|
||||||
with (yield from self._is_updating):
|
with (yield from self._is_updating):
|
||||||
yield from self.hass.loop.run_in_executor(
|
yield from self.hass.async_add_job(
|
||||||
None, update_config, self.hass.config.path(YAML_DEVICES),
|
update_config, self.hass.config.path(YAML_DEVICES),
|
||||||
dev_id, device)
|
dev_id, device)
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
@ -608,7 +608,7 @@ class DeviceScanner(object):
|
|||||||
|
|
||||||
This method must be run in the event loop and returns a coroutine.
|
This method must be run in the event loop and returns a coroutine.
|
||||||
"""
|
"""
|
||||||
return self.hass.loop.run_in_executor(None, self.scan_devices)
|
return self.hass.async_add_job(self.scan_devices)
|
||||||
|
|
||||||
def get_device_name(self, mac: str) -> str:
|
def get_device_name(self, mac: str) -> str:
|
||||||
"""Get device name from mac."""
|
"""Get device name from mac."""
|
||||||
@ -619,7 +619,7 @@ class DeviceScanner(object):
|
|||||||
|
|
||||||
This method must be run in the event loop and returns a coroutine.
|
This method must be run in the event loop and returns a coroutine.
|
||||||
"""
|
"""
|
||||||
return self.hass.loop.run_in_executor(None, self.get_device_name, mac)
|
return self.hass.async_add_job(self.get_device_name, mac)
|
||||||
|
|
||||||
|
|
||||||
def load_config(path: str, hass: HomeAssistantType, consider_home: timedelta):
|
def load_config(path: str, hass: HomeAssistantType, consider_home: timedelta):
|
||||||
@ -650,8 +650,8 @@ def async_load_config(path: str, hass: HomeAssistantType,
|
|||||||
try:
|
try:
|
||||||
result = []
|
result = []
|
||||||
try:
|
try:
|
||||||
devices = yield from hass.loop.run_in_executor(
|
devices = yield from hass.async_add_job(
|
||||||
None, load_yaml_config_file, path)
|
load_yaml_config_file, path)
|
||||||
except HomeAssistantError as err:
|
except HomeAssistantError as err:
|
||||||
_LOGGER.error("Unable to load %s: %s", path, str(err))
|
_LOGGER.error("Unable to load %s: %s", path, str(err))
|
||||||
return []
|
return []
|
||||||
|
@ -1,26 +1,19 @@
|
|||||||
"""The tests for the Home Assistant HTTP component."""
|
"""The tests for the Home Assistant HTTP component."""
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
import logging
|
import asyncio
|
||||||
from ipaddress import ip_address, ip_network
|
from ipaddress import ip_address, ip_network
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import requests
|
import pytest
|
||||||
|
|
||||||
from homeassistant import setup, const
|
from homeassistant import const
|
||||||
|
from homeassistant.setup import async_setup_component
|
||||||
import homeassistant.components.http as http
|
import homeassistant.components.http as http
|
||||||
from homeassistant.components.http.const import (
|
from homeassistant.components.http.const import (
|
||||||
KEY_TRUSTED_NETWORKS, KEY_USE_X_FORWARDED_FOR, HTTP_HEADER_X_FORWARDED_FOR)
|
KEY_TRUSTED_NETWORKS, KEY_USE_X_FORWARDED_FOR, HTTP_HEADER_X_FORWARDED_FOR)
|
||||||
|
|
||||||
from tests.common import get_test_instance_port, get_test_home_assistant
|
|
||||||
|
|
||||||
API_PASSWORD = 'test1234'
|
API_PASSWORD = 'test1234'
|
||||||
SERVER_PORT = get_test_instance_port()
|
|
||||||
HTTP_BASE = '127.0.0.1:{}'.format(SERVER_PORT)
|
|
||||||
HTTP_BASE_URL = 'http://{}'.format(HTTP_BASE)
|
|
||||||
HA_HEADERS = {
|
|
||||||
const.HTTP_HEADER_HA_AUTH: API_PASSWORD,
|
|
||||||
const.HTTP_HEADER_CONTENT_TYPE: const.CONTENT_TYPE_JSON,
|
|
||||||
}
|
|
||||||
# Don't add 127.0.0.1/::1 as trusted, as it may interfere with other test cases
|
# Don't add 127.0.0.1/::1 as trusted, as it may interfere with other test cases
|
||||||
TRUSTED_NETWORKS = ['192.0.2.0/24', '2001:DB8:ABCD::/48', '100.64.0.1',
|
TRUSTED_NETWORKS = ['192.0.2.0/24', '2001:DB8:ABCD::/48', '100.64.0.1',
|
||||||
'FD01:DB8::1']
|
'FD01:DB8::1']
|
||||||
@ -28,142 +21,131 @@ TRUSTED_ADDRESSES = ['100.64.0.1', '192.0.2.100', 'FD01:DB8::1',
|
|||||||
'2001:DB8:ABCD::1']
|
'2001:DB8:ABCD::1']
|
||||||
UNTRUSTED_ADDRESSES = ['198.51.100.1', '2001:DB8:FA1::1', '127.0.0.1', '::1']
|
UNTRUSTED_ADDRESSES = ['198.51.100.1', '2001:DB8:FA1::1', '127.0.0.1', '::1']
|
||||||
|
|
||||||
hass = None
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
def _url(path=''):
|
def mock_api_client(hass, test_client):
|
||||||
"""Helper method to generate URLs."""
|
"""Start the Hass HTTP component."""
|
||||||
return HTTP_BASE_URL + path
|
hass.loop.run_until_complete(async_setup_component(hass, 'api', {
|
||||||
|
'http': {
|
||||||
|
|
||||||
# pylint: disable=invalid-name
|
|
||||||
def setUpModule():
|
|
||||||
"""Initialize a Home Assistant server."""
|
|
||||||
global hass
|
|
||||||
|
|
||||||
hass = get_test_home_assistant()
|
|
||||||
|
|
||||||
setup.setup_component(
|
|
||||||
hass, http.DOMAIN, {
|
|
||||||
http.DOMAIN: {
|
|
||||||
http.CONF_API_PASSWORD: API_PASSWORD,
|
http.CONF_API_PASSWORD: API_PASSWORD,
|
||||||
http.CONF_SERVER_PORT: SERVER_PORT,
|
|
||||||
}
|
}
|
||||||
}
|
}))
|
||||||
)
|
return hass.loop.run_until_complete(test_client(hass.http.app))
|
||||||
|
|
||||||
setup.setup_component(hass, 'api')
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_trusted_networks(hass, mock_api_client):
|
||||||
|
"""Mock trusted networks."""
|
||||||
hass.http.app[KEY_TRUSTED_NETWORKS] = [
|
hass.http.app[KEY_TRUSTED_NETWORKS] = [
|
||||||
ip_network(trusted_network)
|
ip_network(trusted_network)
|
||||||
for trusted_network in TRUSTED_NETWORKS]
|
for trusted_network in TRUSTED_NETWORKS]
|
||||||
|
|
||||||
hass.start()
|
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
# pylint: disable=invalid-name
|
def test_access_denied_without_password(mock_api_client):
|
||||||
def tearDownModule():
|
|
||||||
"""Stop the Home Assistant server."""
|
|
||||||
hass.stop()
|
|
||||||
|
|
||||||
|
|
||||||
class TestHttp:
|
|
||||||
"""Test HTTP component."""
|
|
||||||
|
|
||||||
def test_access_denied_without_password(self):
|
|
||||||
"""Test access without password."""
|
"""Test access without password."""
|
||||||
req = requests.get(_url(const.URL_API))
|
resp = yield from mock_api_client.get(const.URL_API)
|
||||||
|
assert resp.status == 401
|
||||||
|
|
||||||
assert req.status_code == 401
|
|
||||||
|
|
||||||
def test_access_denied_with_wrong_password_in_header(self):
|
@asyncio.coroutine
|
||||||
|
def test_access_denied_with_wrong_password_in_header(mock_api_client):
|
||||||
"""Test access with wrong password."""
|
"""Test access with wrong password."""
|
||||||
req = requests.get(
|
resp = yield from mock_api_client.get(const.URL_API, headers={
|
||||||
_url(const.URL_API),
|
const.HTTP_HEADER_HA_AUTH: 'wrongpassword'
|
||||||
headers={const.HTTP_HEADER_HA_AUTH: 'wrongpassword'})
|
})
|
||||||
|
assert resp.status == 401
|
||||||
|
|
||||||
assert req.status_code == 401
|
|
||||||
|
|
||||||
def test_access_denied_with_x_forwarded_for(self, caplog):
|
@asyncio.coroutine
|
||||||
|
def test_access_denied_with_x_forwarded_for(hass, mock_api_client,
|
||||||
|
mock_trusted_networks):
|
||||||
"""Test access denied through the X-Forwarded-For http header."""
|
"""Test access denied through the X-Forwarded-For http header."""
|
||||||
hass.http.use_x_forwarded_for = True
|
hass.http.use_x_forwarded_for = True
|
||||||
for remote_addr in UNTRUSTED_ADDRESSES:
|
for remote_addr in UNTRUSTED_ADDRESSES:
|
||||||
req = requests.get(_url(const.URL_API), headers={
|
resp = yield from mock_api_client.get(const.URL_API, headers={
|
||||||
HTTP_HEADER_X_FORWARDED_FOR: remote_addr})
|
HTTP_HEADER_X_FORWARDED_FOR: remote_addr})
|
||||||
|
|
||||||
assert req.status_code == 401, \
|
assert resp.status == 401, \
|
||||||
"{} shouldn't be trusted".format(remote_addr)
|
"{} shouldn't be trusted".format(remote_addr)
|
||||||
|
|
||||||
def test_access_denied_with_untrusted_ip(self, caplog):
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def test_access_denied_with_untrusted_ip(mock_api_client,
|
||||||
|
mock_trusted_networks):
|
||||||
"""Test access with an untrusted ip address."""
|
"""Test access with an untrusted ip address."""
|
||||||
for remote_addr in UNTRUSTED_ADDRESSES:
|
for remote_addr in UNTRUSTED_ADDRESSES:
|
||||||
with patch('homeassistant.components.http.'
|
with patch('homeassistant.components.http.'
|
||||||
'util.get_real_ip',
|
'util.get_real_ip',
|
||||||
return_value=ip_address(remote_addr)):
|
return_value=ip_address(remote_addr)):
|
||||||
req = requests.get(
|
resp = yield from mock_api_client.get(
|
||||||
_url(const.URL_API), params={'api_password': ''})
|
const.URL_API, params={'api_password': ''})
|
||||||
|
|
||||||
assert req.status_code == 401, \
|
assert resp.status == 401, \
|
||||||
"{} shouldn't be trusted".format(remote_addr)
|
"{} shouldn't be trusted".format(remote_addr)
|
||||||
|
|
||||||
def test_access_with_password_in_header(self, caplog):
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def test_access_with_password_in_header(mock_api_client, caplog):
|
||||||
"""Test access with password in URL."""
|
"""Test access with password in URL."""
|
||||||
# Hide logging from requests package that we use to test logging
|
# Hide logging from requests package that we use to test logging
|
||||||
caplog.set_level(
|
req = yield from mock_api_client.get(
|
||||||
logging.WARNING, logger='requests.packages.urllib3.connectionpool')
|
const.URL_API, headers={const.HTTP_HEADER_HA_AUTH: API_PASSWORD})
|
||||||
|
|
||||||
req = requests.get(
|
assert req.status == 200
|
||||||
_url(const.URL_API),
|
|
||||||
headers={const.HTTP_HEADER_HA_AUTH: API_PASSWORD})
|
|
||||||
|
|
||||||
assert req.status_code == 200
|
|
||||||
|
|
||||||
logs = caplog.text
|
logs = caplog.text
|
||||||
|
|
||||||
assert const.URL_API in logs
|
assert const.URL_API in logs
|
||||||
assert API_PASSWORD not in logs
|
assert API_PASSWORD not in logs
|
||||||
|
|
||||||
def test_access_denied_with_wrong_password_in_url(self):
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def test_access_denied_with_wrong_password_in_url(mock_api_client):
|
||||||
"""Test access with wrong password."""
|
"""Test access with wrong password."""
|
||||||
req = requests.get(
|
resp = yield from mock_api_client.get(
|
||||||
_url(const.URL_API), params={'api_password': 'wrongpassword'})
|
const.URL_API, params={'api_password': 'wrongpassword'})
|
||||||
|
|
||||||
assert req.status_code == 401
|
assert resp.status == 401
|
||||||
|
|
||||||
def test_access_with_password_in_url(self, caplog):
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def test_access_with_password_in_url(mock_api_client, caplog):
|
||||||
"""Test access with password in URL."""
|
"""Test access with password in URL."""
|
||||||
# Hide logging from requests package that we use to test logging
|
req = yield from mock_api_client.get(
|
||||||
caplog.set_level(
|
const.URL_API, params={'api_password': API_PASSWORD})
|
||||||
logging.WARNING, logger='requests.packages.urllib3.connectionpool')
|
|
||||||
|
|
||||||
req = requests.get(
|
assert req.status == 200
|
||||||
_url(const.URL_API), params={'api_password': API_PASSWORD})
|
|
||||||
|
|
||||||
assert req.status_code == 200
|
|
||||||
|
|
||||||
logs = caplog.text
|
logs = caplog.text
|
||||||
|
|
||||||
assert const.URL_API in logs
|
assert const.URL_API in logs
|
||||||
assert API_PASSWORD not in logs
|
assert API_PASSWORD not in logs
|
||||||
|
|
||||||
def test_access_granted_with_x_forwarded_for(self, caplog):
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def test_access_granted_with_x_forwarded_for(hass, mock_api_client, caplog,
|
||||||
|
mock_trusted_networks):
|
||||||
"""Test access denied through the X-Forwarded-For http header."""
|
"""Test access denied through the X-Forwarded-For http header."""
|
||||||
hass.http.app[KEY_USE_X_FORWARDED_FOR] = True
|
hass.http.app[KEY_USE_X_FORWARDED_FOR] = True
|
||||||
for remote_addr in TRUSTED_ADDRESSES:
|
for remote_addr in TRUSTED_ADDRESSES:
|
||||||
req = requests.get(_url(const.URL_API), headers={
|
resp = yield from mock_api_client.get(const.URL_API, headers={
|
||||||
HTTP_HEADER_X_FORWARDED_FOR: remote_addr})
|
HTTP_HEADER_X_FORWARDED_FOR: remote_addr})
|
||||||
|
|
||||||
assert req.status_code == 200, \
|
assert resp.status == 200, \
|
||||||
"{} should be trusted".format(remote_addr)
|
"{} should be trusted".format(remote_addr)
|
||||||
|
|
||||||
def test_access_granted_with_trusted_ip(self, caplog):
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def test_access_granted_with_trusted_ip(mock_api_client, caplog,
|
||||||
|
mock_trusted_networks):
|
||||||
"""Test access with trusted addresses."""
|
"""Test access with trusted addresses."""
|
||||||
for remote_addr in TRUSTED_ADDRESSES:
|
for remote_addr in TRUSTED_ADDRESSES:
|
||||||
with patch('homeassistant.components.http.'
|
with patch('homeassistant.components.http.'
|
||||||
'auth.get_real_ip',
|
'auth.get_real_ip',
|
||||||
return_value=ip_address(remote_addr)):
|
return_value=ip_address(remote_addr)):
|
||||||
req = requests.get(
|
resp = yield from mock_api_client.get(
|
||||||
_url(const.URL_API), params={'api_password': ''})
|
const.URL_API, params={'api_password': ''})
|
||||||
|
|
||||||
assert req.status_code == 200, \
|
assert resp.status == 200, \
|
||||||
'{} should be trusted'.format(remote_addr)
|
'{} should be trusted'.format(remote_addr)
|
||||||
|
@ -1,117 +1,91 @@
|
|||||||
"""The tests for the Home Assistant HTTP component."""
|
"""The tests for the Home Assistant HTTP component."""
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
|
import asyncio
|
||||||
from ipaddress import ip_address
|
from ipaddress import ip_address
|
||||||
from unittest.mock import patch, mock_open
|
from unittest.mock import patch, mock_open
|
||||||
|
|
||||||
import requests
|
import pytest
|
||||||
|
|
||||||
from homeassistant import setup, const
|
from homeassistant import const
|
||||||
|
from homeassistant.setup import async_setup_component
|
||||||
import homeassistant.components.http as http
|
import homeassistant.components.http as http
|
||||||
from homeassistant.components.http.const import (
|
from homeassistant.components.http.const import (
|
||||||
KEY_BANS_ENABLED, KEY_LOGIN_THRESHOLD, KEY_BANNED_IPS)
|
KEY_BANS_ENABLED, KEY_LOGIN_THRESHOLD, KEY_BANNED_IPS)
|
||||||
from homeassistant.components.http.ban import IpBan, IP_BANS_FILE
|
from homeassistant.components.http.ban import IpBan, IP_BANS_FILE
|
||||||
|
|
||||||
from tests.common import get_test_instance_port, get_test_home_assistant
|
|
||||||
|
|
||||||
API_PASSWORD = 'test1234'
|
API_PASSWORD = 'test1234'
|
||||||
SERVER_PORT = get_test_instance_port()
|
|
||||||
HTTP_BASE = '127.0.0.1:{}'.format(SERVER_PORT)
|
|
||||||
HTTP_BASE_URL = 'http://{}'.format(HTTP_BASE)
|
|
||||||
HA_HEADERS = {
|
|
||||||
const.HTTP_HEADER_HA_AUTH: API_PASSWORD,
|
|
||||||
const.HTTP_HEADER_CONTENT_TYPE: const.CONTENT_TYPE_JSON,
|
|
||||||
}
|
|
||||||
BANNED_IPS = ['200.201.202.203', '100.64.0.2']
|
BANNED_IPS = ['200.201.202.203', '100.64.0.2']
|
||||||
|
|
||||||
hass = None
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
def _url(path=''):
|
def mock_api_client(hass, test_client):
|
||||||
"""Helper method to generate URLs."""
|
"""Start the Hass HTTP component."""
|
||||||
return HTTP_BASE_URL + path
|
hass.loop.run_until_complete(async_setup_component(hass, 'api', {
|
||||||
|
'http': {
|
||||||
|
|
||||||
# pylint: disable=invalid-name
|
|
||||||
def setUpModule():
|
|
||||||
"""Initialize a Home Assistant server."""
|
|
||||||
global hass
|
|
||||||
|
|
||||||
hass = get_test_home_assistant()
|
|
||||||
|
|
||||||
setup.setup_component(
|
|
||||||
hass, http.DOMAIN, {
|
|
||||||
http.DOMAIN: {
|
|
||||||
http.CONF_API_PASSWORD: API_PASSWORD,
|
http.CONF_API_PASSWORD: API_PASSWORD,
|
||||||
http.CONF_SERVER_PORT: SERVER_PORT,
|
|
||||||
}
|
}
|
||||||
}
|
}))
|
||||||
)
|
|
||||||
|
|
||||||
setup.setup_component(hass, 'api')
|
|
||||||
|
|
||||||
hass.http.app[KEY_BANNED_IPS] = [IpBan(banned_ip) for banned_ip
|
hass.http.app[KEY_BANNED_IPS] = [IpBan(banned_ip) for banned_ip
|
||||||
in BANNED_IPS]
|
in BANNED_IPS]
|
||||||
hass.start()
|
return hass.loop.run_until_complete(test_client(hass.http.app))
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=invalid-name
|
@asyncio.coroutine
|
||||||
def tearDownModule():
|
def test_access_from_banned_ip(hass, mock_api_client):
|
||||||
"""Stop the Home Assistant server."""
|
|
||||||
hass.stop()
|
|
||||||
|
|
||||||
|
|
||||||
class TestHttp:
|
|
||||||
"""Test HTTP component."""
|
|
||||||
|
|
||||||
def test_access_from_banned_ip(self):
|
|
||||||
"""Test accessing to server from banned IP. Both trusted and not."""
|
"""Test accessing to server from banned IP. Both trusted and not."""
|
||||||
hass.http.app[KEY_BANS_ENABLED] = True
|
hass.http.app[KEY_BANS_ENABLED] = True
|
||||||
for remote_addr in BANNED_IPS:
|
for remote_addr in BANNED_IPS:
|
||||||
with patch('homeassistant.components.http.'
|
with patch('homeassistant.components.http.'
|
||||||
'ban.get_real_ip',
|
'ban.get_real_ip',
|
||||||
return_value=ip_address(remote_addr)):
|
return_value=ip_address(remote_addr)):
|
||||||
req = requests.get(
|
resp = yield from mock_api_client.get(
|
||||||
_url(const.URL_API))
|
const.URL_API)
|
||||||
assert req.status_code == 403
|
assert resp.status == 403
|
||||||
|
|
||||||
def test_access_from_banned_ip_when_ban_is_off(self):
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def test_access_from_banned_ip_when_ban_is_off(hass, mock_api_client):
|
||||||
"""Test accessing to server from banned IP when feature is off."""
|
"""Test accessing to server from banned IP when feature is off."""
|
||||||
hass.http.app[KEY_BANS_ENABLED] = False
|
hass.http.app[KEY_BANS_ENABLED] = False
|
||||||
for remote_addr in BANNED_IPS:
|
for remote_addr in BANNED_IPS:
|
||||||
with patch('homeassistant.components.http.'
|
with patch('homeassistant.components.http.'
|
||||||
'ban.get_real_ip',
|
'ban.get_real_ip',
|
||||||
return_value=ip_address(remote_addr)):
|
return_value=ip_address(remote_addr)):
|
||||||
req = requests.get(
|
resp = yield from mock_api_client.get(
|
||||||
_url(const.URL_API),
|
const.URL_API,
|
||||||
headers={const.HTTP_HEADER_HA_AUTH: API_PASSWORD})
|
headers={const.HTTP_HEADER_HA_AUTH: API_PASSWORD})
|
||||||
assert req.status_code == 200
|
assert resp.status == 200
|
||||||
|
|
||||||
def test_ip_bans_file_creation(self):
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def test_ip_bans_file_creation(hass, mock_api_client):
|
||||||
"""Testing if banned IP file created."""
|
"""Testing if banned IP file created."""
|
||||||
hass.http.app[KEY_BANS_ENABLED] = True
|
hass.http.app[KEY_BANS_ENABLED] = True
|
||||||
hass.http.app[KEY_LOGIN_THRESHOLD] = 1
|
hass.http.app[KEY_LOGIN_THRESHOLD] = 1
|
||||||
|
|
||||||
m = mock_open()
|
m = mock_open()
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
def call_server():
|
def call_server():
|
||||||
with patch('homeassistant.components.http.'
|
with patch('homeassistant.components.http.'
|
||||||
'ban.get_real_ip',
|
'ban.get_real_ip',
|
||||||
return_value=ip_address("200.201.202.204")):
|
return_value=ip_address("200.201.202.204")):
|
||||||
return requests.get(
|
resp = yield from mock_api_client.get(
|
||||||
_url(const.URL_API),
|
const.URL_API,
|
||||||
headers={const.HTTP_HEADER_HA_AUTH: 'Wrong password'})
|
headers={const.HTTP_HEADER_HA_AUTH: 'Wrong password'})
|
||||||
|
return resp
|
||||||
|
|
||||||
with patch('homeassistant.components.http.ban.open', m, create=True):
|
with patch('homeassistant.components.http.ban.open', m, create=True):
|
||||||
req = call_server()
|
resp = yield from call_server()
|
||||||
assert req.status_code == 401
|
assert resp.status == 401
|
||||||
assert len(hass.http.app[KEY_BANNED_IPS]) == len(BANNED_IPS)
|
assert len(hass.http.app[KEY_BANNED_IPS]) == len(BANNED_IPS)
|
||||||
assert m.call_count == 0
|
assert m.call_count == 0
|
||||||
|
|
||||||
req = call_server()
|
resp = yield from call_server()
|
||||||
assert req.status_code == 401
|
assert resp.status == 401
|
||||||
assert len(hass.http.app[KEY_BANNED_IPS]) == len(BANNED_IPS) + 1
|
assert len(hass.http.app[KEY_BANNED_IPS]) == len(BANNED_IPS) + 1
|
||||||
m.assert_called_once_with(hass.config.path(IP_BANS_FILE), 'a')
|
m.assert_called_once_with(hass.config.path(IP_BANS_FILE), 'a')
|
||||||
|
|
||||||
req = call_server()
|
resp = yield from call_server()
|
||||||
assert req.status_code == 403
|
assert resp.status == 403
|
||||||
assert m.call_count == 1
|
assert m.call_count == 1
|
||||||
|
@ -1,182 +1,156 @@
|
|||||||
"""The tests for the Home Assistant API component."""
|
"""The tests for the Home Assistant API component."""
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
from contextlib import closing
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import unittest
|
|
||||||
|
|
||||||
import requests
|
import pytest
|
||||||
|
|
||||||
from homeassistant import setup, const
|
from homeassistant import const
|
||||||
import homeassistant.core as ha
|
import homeassistant.core as ha
|
||||||
import homeassistant.components.http as http
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
from tests.common import get_test_instance_port, get_test_home_assistant
|
|
||||||
|
|
||||||
API_PASSWORD = "test1234"
|
|
||||||
SERVER_PORT = get_test_instance_port()
|
|
||||||
HTTP_BASE_URL = "http://127.0.0.1:{}".format(SERVER_PORT)
|
|
||||||
HA_HEADERS = {
|
|
||||||
const.HTTP_HEADER_HA_AUTH: API_PASSWORD,
|
|
||||||
const.HTTP_HEADER_CONTENT_TYPE: const.CONTENT_TYPE_JSON,
|
|
||||||
}
|
|
||||||
|
|
||||||
hass = None
|
|
||||||
|
|
||||||
|
|
||||||
def _url(path=""):
|
@pytest.fixture
|
||||||
"""Helper method to generate URLs."""
|
def mock_api_client(hass, test_client):
|
||||||
return HTTP_BASE_URL + path
|
"""Start the Hass HTTP component."""
|
||||||
|
hass.loop.run_until_complete(async_setup_component(hass, 'api', {}))
|
||||||
|
return hass.loop.run_until_complete(test_client(hass.http.app))
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=invalid-name
|
@asyncio.coroutine
|
||||||
def setUpModule():
|
def test_api_list_state_entities(hass, mock_api_client):
|
||||||
"""Initialize a Home Assistant server."""
|
|
||||||
global hass
|
|
||||||
|
|
||||||
hass = get_test_home_assistant()
|
|
||||||
|
|
||||||
hass.bus.listen('test_event', lambda _: _)
|
|
||||||
hass.states.set('test.test', 'a_state')
|
|
||||||
|
|
||||||
setup.setup_component(
|
|
||||||
hass, http.DOMAIN,
|
|
||||||
{http.DOMAIN: {http.CONF_API_PASSWORD: API_PASSWORD,
|
|
||||||
http.CONF_SERVER_PORT: SERVER_PORT}})
|
|
||||||
|
|
||||||
setup.setup_component(hass, 'api')
|
|
||||||
|
|
||||||
hass.start()
|
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=invalid-name
|
|
||||||
def tearDownModule():
|
|
||||||
"""Stop the Home Assistant server."""
|
|
||||||
hass.stop()
|
|
||||||
|
|
||||||
|
|
||||||
class TestAPI(unittest.TestCase):
|
|
||||||
"""Test the API."""
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
"""Stop everything that was started."""
|
|
||||||
hass.block_till_done()
|
|
||||||
|
|
||||||
def test_api_list_state_entities(self):
|
|
||||||
"""Test if the debug interface allows us to list state entities."""
|
"""Test if the debug interface allows us to list state entities."""
|
||||||
req = requests.get(_url(const.URL_API_STATES),
|
hass.states.async_set('test.entity', 'hello')
|
||||||
headers=HA_HEADERS)
|
resp = yield from mock_api_client.get(const.URL_API_STATES)
|
||||||
|
assert resp.status == 200
|
||||||
|
json = yield from resp.json()
|
||||||
|
|
||||||
remote_data = [ha.State.from_dict(item) for item in req.json()]
|
remote_data = [ha.State.from_dict(item) for item in json]
|
||||||
|
assert remote_data == hass.states.async_all()
|
||||||
|
|
||||||
self.assertEqual(hass.states.all(), remote_data)
|
|
||||||
|
|
||||||
def test_api_get_state(self):
|
@asyncio.coroutine
|
||||||
|
def test_api_get_state(hass, mock_api_client):
|
||||||
"""Test if the debug interface allows us to get a state."""
|
"""Test if the debug interface allows us to get a state."""
|
||||||
req = requests.get(
|
hass.states.async_set('hello.world', 'nice', {
|
||||||
_url(const.URL_API_STATES_ENTITY.format("test.test")),
|
'attr': 1,
|
||||||
headers=HA_HEADERS)
|
})
|
||||||
|
resp = yield from mock_api_client.get(
|
||||||
|
const.URL_API_STATES_ENTITY.format("hello.world"))
|
||||||
|
assert resp.status == 200
|
||||||
|
json = yield from resp.json()
|
||||||
|
|
||||||
data = ha.State.from_dict(req.json())
|
data = ha.State.from_dict(json)
|
||||||
|
|
||||||
state = hass.states.get("test.test")
|
state = hass.states.get("hello.world")
|
||||||
|
|
||||||
self.assertEqual(state.state, data.state)
|
assert data.state == state.state
|
||||||
self.assertEqual(state.last_changed, data.last_changed)
|
assert data.last_changed == state.last_changed
|
||||||
self.assertEqual(state.attributes, data.attributes)
|
assert data.attributes == state.attributes
|
||||||
|
|
||||||
def test_api_get_non_existing_state(self):
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def test_api_get_non_existing_state(hass, mock_api_client):
|
||||||
"""Test if the debug interface allows us to get a state."""
|
"""Test if the debug interface allows us to get a state."""
|
||||||
req = requests.get(
|
resp = yield from mock_api_client.get(
|
||||||
_url(const.URL_API_STATES_ENTITY.format("does_not_exist")),
|
const.URL_API_STATES_ENTITY.format("does_not_exist"))
|
||||||
headers=HA_HEADERS)
|
assert resp.status == 404
|
||||||
|
|
||||||
self.assertEqual(404, req.status_code)
|
|
||||||
|
|
||||||
def test_api_state_change(self):
|
@asyncio.coroutine
|
||||||
|
def test_api_state_change(hass, mock_api_client):
|
||||||
"""Test if we can change the state of an entity that exists."""
|
"""Test if we can change the state of an entity that exists."""
|
||||||
hass.states.set("test.test", "not_to_be_set")
|
hass.states.async_set("test.test", "not_to_be_set")
|
||||||
|
|
||||||
requests.post(_url(const.URL_API_STATES_ENTITY.format("test.test")),
|
yield from mock_api_client.post(
|
||||||
data=json.dumps({"state": "debug_state_change2"}),
|
const.URL_API_STATES_ENTITY.format("test.test"),
|
||||||
headers=HA_HEADERS)
|
json={"state": "debug_state_change2"})
|
||||||
|
|
||||||
self.assertEqual("debug_state_change2",
|
assert hass.states.get("test.test").state == "debug_state_change2"
|
||||||
hass.states.get("test.test").state)
|
|
||||||
|
|
||||||
# pylint: disable=invalid-name
|
|
||||||
def test_api_state_change_of_non_existing_entity(self):
|
# pylint: disable=invalid-name
|
||||||
|
@asyncio.coroutine
|
||||||
|
def test_api_state_change_of_non_existing_entity(hass, mock_api_client):
|
||||||
"""Test if changing a state of a non existing entity is possible."""
|
"""Test if changing a state of a non existing entity is possible."""
|
||||||
new_state = "debug_state_change"
|
new_state = "debug_state_change"
|
||||||
|
|
||||||
req = requests.post(
|
resp = yield from mock_api_client.post(
|
||||||
_url(const.URL_API_STATES_ENTITY.format(
|
const.URL_API_STATES_ENTITY.format("test_entity.that_does_not_exist"),
|
||||||
"test_entity.that_does_not_exist")),
|
json={'state': new_state})
|
||||||
data=json.dumps({'state': new_state}),
|
|
||||||
headers=HA_HEADERS)
|
|
||||||
|
|
||||||
cur_state = (hass.states.
|
assert resp.status == 201
|
||||||
get("test_entity.that_does_not_exist").state)
|
|
||||||
|
|
||||||
self.assertEqual(201, req.status_code)
|
assert hass.states.get("test_entity.that_does_not_exist").state == \
|
||||||
self.assertEqual(cur_state, new_state)
|
new_state
|
||||||
|
|
||||||
# pylint: disable=invalid-name
|
|
||||||
def test_api_state_change_with_bad_data(self):
|
# pylint: disable=invalid-name
|
||||||
|
@asyncio.coroutine
|
||||||
|
def test_api_state_change_with_bad_data(hass, mock_api_client):
|
||||||
"""Test if API sends appropriate error if we omit state."""
|
"""Test if API sends appropriate error if we omit state."""
|
||||||
req = requests.post(
|
resp = yield from mock_api_client.post(
|
||||||
_url(const.URL_API_STATES_ENTITY.format(
|
const.URL_API_STATES_ENTITY.format("test_entity.that_does_not_exist"),
|
||||||
"test_entity.that_does_not_exist")),
|
json={})
|
||||||
data=json.dumps({}),
|
|
||||||
headers=HA_HEADERS)
|
|
||||||
|
|
||||||
self.assertEqual(400, req.status_code)
|
assert resp.status == 400
|
||||||
|
|
||||||
# pylint: disable=invalid-name
|
|
||||||
def test_api_state_change_push(self):
|
# pylint: disable=invalid-name
|
||||||
|
@asyncio.coroutine
|
||||||
|
def test_api_state_change_push(hass, mock_api_client):
|
||||||
"""Test if we can push a change the state of an entity."""
|
"""Test if we can push a change the state of an entity."""
|
||||||
hass.states.set("test.test", "not_to_be_set")
|
hass.states.async_set("test.test", "not_to_be_set")
|
||||||
|
|
||||||
events = []
|
events = []
|
||||||
hass.bus.listen(const.EVENT_STATE_CHANGED,
|
|
||||||
lambda ev: events.append(ev))
|
|
||||||
|
|
||||||
requests.post(_url(const.URL_API_STATES_ENTITY.format("test.test")),
|
@ha.callback
|
||||||
data=json.dumps({"state": "not_to_be_set"}),
|
def event_listener(event):
|
||||||
headers=HA_HEADERS)
|
"""Track events."""
|
||||||
hass.block_till_done()
|
events.append(event)
|
||||||
self.assertEqual(0, len(events))
|
|
||||||
|
|
||||||
requests.post(_url(const.URL_API_STATES_ENTITY.format("test.test")),
|
hass.bus.async_listen(const.EVENT_STATE_CHANGED, event_listener)
|
||||||
data=json.dumps({"state": "not_to_be_set",
|
|
||||||
"force_update": True}),
|
|
||||||
headers=HA_HEADERS)
|
|
||||||
hass.block_till_done()
|
|
||||||
self.assertEqual(1, len(events))
|
|
||||||
|
|
||||||
# pylint: disable=invalid-name
|
yield from mock_api_client.post(
|
||||||
def test_api_fire_event_with_no_data(self):
|
const.URL_API_STATES_ENTITY.format("test.test"),
|
||||||
|
json={"state": "not_to_be_set"})
|
||||||
|
yield from hass.async_block_till_done()
|
||||||
|
assert len(events) == 0
|
||||||
|
|
||||||
|
yield from mock_api_client.post(
|
||||||
|
const.URL_API_STATES_ENTITY.format("test.test"),
|
||||||
|
json={"state": "not_to_be_set", "force_update": True})
|
||||||
|
yield from hass.async_block_till_done()
|
||||||
|
assert len(events) == 1
|
||||||
|
|
||||||
|
|
||||||
|
# pylint: disable=invalid-name
|
||||||
|
@asyncio.coroutine
|
||||||
|
def test_api_fire_event_with_no_data(hass, mock_api_client):
|
||||||
"""Test if the API allows us to fire an event."""
|
"""Test if the API allows us to fire an event."""
|
||||||
test_value = []
|
test_value = []
|
||||||
|
|
||||||
|
@ha.callback
|
||||||
def listener(event):
|
def listener(event):
|
||||||
"""Helper method that will verify our event got called."""
|
"""Helper method that will verify our event got called."""
|
||||||
test_value.append(1)
|
test_value.append(1)
|
||||||
|
|
||||||
hass.bus.listen_once("test.event_no_data", listener)
|
hass.bus.async_listen_once("test.event_no_data", listener)
|
||||||
|
|
||||||
requests.post(
|
yield from mock_api_client.post(
|
||||||
_url(const.URL_API_EVENTS_EVENT.format("test.event_no_data")),
|
const.URL_API_EVENTS_EVENT.format("test.event_no_data"))
|
||||||
headers=HA_HEADERS)
|
yield from hass.async_block_till_done()
|
||||||
|
|
||||||
hass.block_till_done()
|
assert len(test_value) == 1
|
||||||
|
|
||||||
self.assertEqual(1, len(test_value))
|
|
||||||
|
|
||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
def test_api_fire_event_with_data(self):
|
@asyncio.coroutine
|
||||||
|
def test_api_fire_event_with_data(hass, mock_api_client):
|
||||||
"""Test if the API allows us to fire an event."""
|
"""Test if the API allows us to fire an event."""
|
||||||
test_value = []
|
test_value = []
|
||||||
|
|
||||||
|
@ha.callback
|
||||||
def listener(event):
|
def listener(event):
|
||||||
"""Helper method that will verify that our event got called.
|
"""Helper method that will verify that our event got called.
|
||||||
|
|
||||||
@ -185,91 +159,98 @@ class TestAPI(unittest.TestCase):
|
|||||||
if "test" in event.data:
|
if "test" in event.data:
|
||||||
test_value.append(1)
|
test_value.append(1)
|
||||||
|
|
||||||
hass.bus.listen_once("test_event_with_data", listener)
|
hass.bus.async_listen_once("test_event_with_data", listener)
|
||||||
|
|
||||||
requests.post(
|
yield from mock_api_client.post(
|
||||||
_url(const.URL_API_EVENTS_EVENT.format("test_event_with_data")),
|
const.URL_API_EVENTS_EVENT.format("test_event_with_data"),
|
||||||
data=json.dumps({"test": 1}),
|
json={"test": 1})
|
||||||
headers=HA_HEADERS)
|
|
||||||
|
|
||||||
hass.block_till_done()
|
yield from hass.async_block_till_done()
|
||||||
|
|
||||||
self.assertEqual(1, len(test_value))
|
assert len(test_value) == 1
|
||||||
|
|
||||||
# pylint: disable=invalid-name
|
|
||||||
def test_api_fire_event_with_invalid_json(self):
|
# pylint: disable=invalid-name
|
||||||
|
@asyncio.coroutine
|
||||||
|
def test_api_fire_event_with_invalid_json(hass, mock_api_client):
|
||||||
"""Test if the API allows us to fire an event."""
|
"""Test if the API allows us to fire an event."""
|
||||||
test_value = []
|
test_value = []
|
||||||
|
|
||||||
|
@ha.callback
|
||||||
def listener(event):
|
def listener(event):
|
||||||
"""Helper method that will verify our event got called."""
|
"""Helper method that will verify our event got called."""
|
||||||
test_value.append(1)
|
test_value.append(1)
|
||||||
|
|
||||||
hass.bus.listen_once("test_event_bad_data", listener)
|
hass.bus.async_listen_once("test_event_bad_data", listener)
|
||||||
|
|
||||||
req = requests.post(
|
resp = yield from mock_api_client.post(
|
||||||
_url(const.URL_API_EVENTS_EVENT.format("test_event_bad_data")),
|
const.URL_API_EVENTS_EVENT.format("test_event_bad_data"),
|
||||||
data=json.dumps('not an object'),
|
data=json.dumps('not an object'))
|
||||||
headers=HA_HEADERS)
|
|
||||||
|
|
||||||
hass.block_till_done()
|
yield from hass.async_block_till_done()
|
||||||
|
|
||||||
self.assertEqual(400, req.status_code)
|
assert resp.status == 400
|
||||||
self.assertEqual(0, len(test_value))
|
assert len(test_value) == 0
|
||||||
|
|
||||||
# Try now with valid but unusable JSON
|
# Try now with valid but unusable JSON
|
||||||
req = requests.post(
|
resp = yield from mock_api_client.post(
|
||||||
_url(const.URL_API_EVENTS_EVENT.format("test_event_bad_data")),
|
const.URL_API_EVENTS_EVENT.format("test_event_bad_data"),
|
||||||
data=json.dumps([1, 2, 3]),
|
data=json.dumps([1, 2, 3]))
|
||||||
headers=HA_HEADERS)
|
|
||||||
|
|
||||||
hass.block_till_done()
|
yield from hass.async_block_till_done()
|
||||||
|
|
||||||
self.assertEqual(400, req.status_code)
|
assert resp.status == 400
|
||||||
self.assertEqual(0, len(test_value))
|
assert len(test_value) == 0
|
||||||
|
|
||||||
def test_api_get_config(self):
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def test_api_get_config(hass, mock_api_client):
|
||||||
"""Test the return of the configuration."""
|
"""Test the return of the configuration."""
|
||||||
req = requests.get(_url(const.URL_API_CONFIG),
|
resp = yield from mock_api_client.get(const.URL_API_CONFIG)
|
||||||
headers=HA_HEADERS)
|
result = yield from resp.json()
|
||||||
result = req.json()
|
|
||||||
if 'components' in result:
|
if 'components' in result:
|
||||||
result['components'] = set(result['components'])
|
result['components'] = set(result['components'])
|
||||||
|
|
||||||
self.assertEqual(hass.config.as_dict(), result)
|
assert hass.config.as_dict() == result
|
||||||
|
|
||||||
def test_api_get_components(self):
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def test_api_get_components(hass, mock_api_client):
|
||||||
"""Test the return of the components."""
|
"""Test the return of the components."""
|
||||||
req = requests.get(_url(const.URL_API_COMPONENTS),
|
resp = yield from mock_api_client.get(const.URL_API_COMPONENTS)
|
||||||
headers=HA_HEADERS)
|
result = yield from resp.json()
|
||||||
self.assertEqual(hass.config.components, set(req.json()))
|
assert set(result) == hass.config.components
|
||||||
|
|
||||||
def test_api_get_event_listeners(self):
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def test_api_get_event_listeners(hass, mock_api_client):
|
||||||
"""Test if we can get the list of events being listened for."""
|
"""Test if we can get the list of events being listened for."""
|
||||||
req = requests.get(_url(const.URL_API_EVENTS),
|
resp = yield from mock_api_client.get(const.URL_API_EVENTS)
|
||||||
headers=HA_HEADERS)
|
data = yield from resp.json()
|
||||||
|
|
||||||
local = hass.bus.listeners
|
local = hass.bus.async_listeners()
|
||||||
|
|
||||||
for event in req.json():
|
for event in data:
|
||||||
self.assertEqual(event["listener_count"],
|
assert local.pop(event["event"]) == event["listener_count"]
|
||||||
local.pop(event["event"]))
|
|
||||||
|
|
||||||
self.assertEqual(0, len(local))
|
assert len(local) == 0
|
||||||
|
|
||||||
def test_api_get_services(self):
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def test_api_get_services(hass, mock_api_client):
|
||||||
"""Test if we can get a dict describing current services."""
|
"""Test if we can get a dict describing current services."""
|
||||||
req = requests.get(_url(const.URL_API_SERVICES),
|
resp = yield from mock_api_client.get(const.URL_API_SERVICES)
|
||||||
headers=HA_HEADERS)
|
data = yield from resp.json()
|
||||||
|
local_services = hass.services.async_services()
|
||||||
|
|
||||||
local_services = hass.services.services
|
for serv_domain in data:
|
||||||
|
|
||||||
for serv_domain in req.json():
|
|
||||||
local = local_services.pop(serv_domain["domain"])
|
local = local_services.pop(serv_domain["domain"])
|
||||||
|
|
||||||
self.assertEqual(local, serv_domain["services"])
|
assert serv_domain["services"] == local
|
||||||
|
|
||||||
def test_api_call_service_no_data(self):
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def test_api_call_service_no_data(hass, mock_api_client):
|
||||||
"""Test if the API allows us to call a service."""
|
"""Test if the API allows us to call a service."""
|
||||||
test_value = []
|
test_value = []
|
||||||
|
|
||||||
@ -278,18 +259,17 @@ class TestAPI(unittest.TestCase):
|
|||||||
"""Helper method that will verify that our service got called."""
|
"""Helper method that will verify that our service got called."""
|
||||||
test_value.append(1)
|
test_value.append(1)
|
||||||
|
|
||||||
hass.services.register("test_domain", "test_service", listener)
|
hass.services.async_register("test_domain", "test_service", listener)
|
||||||
|
|
||||||
requests.post(
|
yield from mock_api_client.post(
|
||||||
_url(const.URL_API_SERVICES_SERVICE.format(
|
const.URL_API_SERVICES_SERVICE.format(
|
||||||
"test_domain", "test_service")),
|
"test_domain", "test_service"))
|
||||||
headers=HA_HEADERS)
|
yield from hass.async_block_till_done()
|
||||||
|
assert len(test_value) == 1
|
||||||
|
|
||||||
hass.block_till_done()
|
|
||||||
|
|
||||||
self.assertEqual(1, len(test_value))
|
@asyncio.coroutine
|
||||||
|
def test_api_call_service_with_data(hass, mock_api_client):
|
||||||
def test_api_call_service_with_data(self):
|
|
||||||
"""Test if the API allows us to call a service."""
|
"""Test if the API allows us to call a service."""
|
||||||
test_value = []
|
test_value = []
|
||||||
|
|
||||||
@ -302,81 +282,87 @@ class TestAPI(unittest.TestCase):
|
|||||||
if "test" in service_call.data:
|
if "test" in service_call.data:
|
||||||
test_value.append(1)
|
test_value.append(1)
|
||||||
|
|
||||||
hass.services.register("test_domain", "test_service", listener)
|
hass.services.async_register("test_domain", "test_service", listener)
|
||||||
|
|
||||||
requests.post(
|
yield from mock_api_client.post(
|
||||||
_url(const.URL_API_SERVICES_SERVICE.format(
|
const.URL_API_SERVICES_SERVICE.format("test_domain", "test_service"),
|
||||||
"test_domain", "test_service")),
|
json={"test": 1})
|
||||||
data=json.dumps({"test": 1}),
|
|
||||||
headers=HA_HEADERS)
|
|
||||||
|
|
||||||
hass.block_till_done()
|
yield from hass.async_block_till_done()
|
||||||
|
assert len(test_value) == 1
|
||||||
|
|
||||||
self.assertEqual(1, len(test_value))
|
|
||||||
|
|
||||||
def test_api_template(self):
|
@asyncio.coroutine
|
||||||
|
def test_api_template(hass, mock_api_client):
|
||||||
"""Test the template API."""
|
"""Test the template API."""
|
||||||
hass.states.set('sensor.temperature', 10)
|
hass.states.async_set('sensor.temperature', 10)
|
||||||
|
|
||||||
req = requests.post(
|
resp = yield from mock_api_client.post(
|
||||||
_url(const.URL_API_TEMPLATE),
|
const.URL_API_TEMPLATE,
|
||||||
json={"template": '{{ states.sensor.temperature.state }}'},
|
json={"template": '{{ states.sensor.temperature.state }}'})
|
||||||
headers=HA_HEADERS)
|
|
||||||
|
|
||||||
self.assertEqual('10', req.text)
|
body = yield from resp.text()
|
||||||
|
|
||||||
def test_api_template_error(self):
|
assert body == '10'
|
||||||
|
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def test_api_template_error(hass, mock_api_client):
|
||||||
"""Test the template API."""
|
"""Test the template API."""
|
||||||
hass.states.set('sensor.temperature', 10)
|
hass.states.async_set('sensor.temperature', 10)
|
||||||
|
|
||||||
req = requests.post(
|
resp = yield from mock_api_client.post(
|
||||||
_url(const.URL_API_TEMPLATE),
|
const.URL_API_TEMPLATE,
|
||||||
data=json.dumps({"template":
|
json={"template": '{{ states.sensor.temperature.state'})
|
||||||
'{{ states.sensor.temperature.state'}),
|
|
||||||
headers=HA_HEADERS)
|
|
||||||
|
|
||||||
self.assertEqual(400, req.status_code)
|
assert resp.status == 400
|
||||||
|
|
||||||
def test_stream(self):
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def test_stream(hass, mock_api_client):
|
||||||
"""Test the stream."""
|
"""Test the stream."""
|
||||||
listen_count = self._listen_count()
|
listen_count = _listen_count(hass)
|
||||||
with closing(requests.get(_url(const.URL_API_STREAM), timeout=3,
|
|
||||||
stream=True, headers=HA_HEADERS)) as req:
|
|
||||||
stream = req.iter_content(1)
|
|
||||||
self.assertEqual(listen_count + 1, self._listen_count())
|
|
||||||
|
|
||||||
hass.bus.fire('test_event')
|
resp = yield from mock_api_client.get(const.URL_API_STREAM)
|
||||||
|
assert resp.status == 200
|
||||||
|
assert listen_count + 1 == _listen_count(hass)
|
||||||
|
|
||||||
data = self._stream_next_event(stream)
|
hass.bus.async_fire('test_event')
|
||||||
|
|
||||||
self.assertEqual('test_event', data['event_type'])
|
data = yield from _stream_next_event(resp.content)
|
||||||
|
|
||||||
def test_stream_with_restricted(self):
|
assert data['event_type'] == 'test_event'
|
||||||
|
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def test_stream_with_restricted(hass, mock_api_client):
|
||||||
"""Test the stream with restrictions."""
|
"""Test the stream with restrictions."""
|
||||||
listen_count = self._listen_count()
|
listen_count = _listen_count(hass)
|
||||||
url = _url('{}?restrict=test_event1,test_event3'.format(
|
|
||||||
const.URL_API_STREAM))
|
|
||||||
with closing(requests.get(url, stream=True, timeout=3,
|
|
||||||
headers=HA_HEADERS)) as req:
|
|
||||||
stream = req.iter_content(1)
|
|
||||||
self.assertEqual(listen_count + 1, self._listen_count())
|
|
||||||
|
|
||||||
hass.bus.fire('test_event1')
|
resp = yield from mock_api_client.get(
|
||||||
data = self._stream_next_event(stream)
|
'{}?restrict=test_event1,test_event3'.format(const.URL_API_STREAM))
|
||||||
self.assertEqual('test_event1', data['event_type'])
|
assert resp.status == 200
|
||||||
|
assert listen_count + 1 == _listen_count(hass)
|
||||||
|
|
||||||
hass.bus.fire('test_event2')
|
hass.bus.async_fire('test_event1')
|
||||||
hass.bus.fire('test_event3')
|
data = yield from _stream_next_event(resp.content)
|
||||||
|
assert data['event_type'] == 'test_event1'
|
||||||
|
|
||||||
data = self._stream_next_event(stream)
|
hass.bus.async_fire('test_event2')
|
||||||
self.assertEqual('test_event3', data['event_type'])
|
hass.bus.async_fire('test_event3')
|
||||||
|
data = yield from _stream_next_event(resp.content)
|
||||||
|
assert data['event_type'] == 'test_event3'
|
||||||
|
|
||||||
def _stream_next_event(self, stream):
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def _stream_next_event(stream):
|
||||||
"""Read the stream for next event while ignoring ping."""
|
"""Read the stream for next event while ignoring ping."""
|
||||||
while True:
|
while True:
|
||||||
data = b''
|
|
||||||
last_new_line = False
|
last_new_line = False
|
||||||
for dat in stream:
|
data = b''
|
||||||
|
|
||||||
|
while True:
|
||||||
|
dat = yield from stream.read(1)
|
||||||
if dat == b'\n' and last_new_line:
|
if dat == b'\n' and last_new_line:
|
||||||
break
|
break
|
||||||
data += dat
|
data += dat
|
||||||
@ -386,9 +372,9 @@ class TestAPI(unittest.TestCase):
|
|||||||
|
|
||||||
if conv != 'ping':
|
if conv != 'ping':
|
||||||
break
|
break
|
||||||
|
|
||||||
return json.loads(conv)
|
return json.loads(conv)
|
||||||
|
|
||||||
def _listen_count(self):
|
|
||||||
|
def _listen_count(hass):
|
||||||
"""Return number of event listeners."""
|
"""Return number of event listeners."""
|
||||||
return sum(hass.bus.listeners.values())
|
return sum(hass.bus.async_listeners().values())
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
"""The tests device sun light trigger component."""
|
"""The tests device sun light trigger component."""
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import os
|
|
||||||
import unittest
|
import unittest
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
@ -12,32 +11,7 @@ from homeassistant.components import (
|
|||||||
device_tracker, light, device_sun_light_trigger)
|
device_tracker, light, device_sun_light_trigger)
|
||||||
from homeassistant.util import dt as dt_util
|
from homeassistant.util import dt as dt_util
|
||||||
|
|
||||||
from tests.common import (
|
from tests.common import get_test_home_assistant, fire_time_changed
|
||||||
get_test_config_dir, get_test_home_assistant, fire_time_changed)
|
|
||||||
|
|
||||||
|
|
||||||
KNOWN_DEV_YAML_PATH = os.path.join(get_test_config_dir(),
|
|
||||||
device_tracker.YAML_DEVICES)
|
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=invalid-name
|
|
||||||
def setUpModule():
|
|
||||||
"""Write a device tracker known devices file to be used."""
|
|
||||||
device_tracker.update_config(
|
|
||||||
KNOWN_DEV_YAML_PATH, 'device_1', device_tracker.Device(
|
|
||||||
None, None, True, 'device_1', 'DEV1',
|
|
||||||
picture='http://example.com/dev1.jpg'))
|
|
||||||
|
|
||||||
device_tracker.update_config(
|
|
||||||
KNOWN_DEV_YAML_PATH, 'device_2', device_tracker.Device(
|
|
||||||
None, None, True, 'device_2', 'DEV2',
|
|
||||||
picture='http://example.com/dev2.jpg'))
|
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=invalid-name
|
|
||||||
def tearDownModule():
|
|
||||||
"""Remove device tracker known devices file."""
|
|
||||||
os.remove(KNOWN_DEV_YAML_PATH)
|
|
||||||
|
|
||||||
|
|
||||||
class TestDeviceSunLightTrigger(unittest.TestCase):
|
class TestDeviceSunLightTrigger(unittest.TestCase):
|
||||||
@ -55,6 +29,25 @@ class TestDeviceSunLightTrigger(unittest.TestCase):
|
|||||||
|
|
||||||
loader.get_component('light.test').init()
|
loader.get_component('light.test').init()
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
'homeassistant.components.device_tracker.load_yaml_config_file',
|
||||||
|
return_value={
|
||||||
|
'device_1': {
|
||||||
|
'hide_if_away': False,
|
||||||
|
'mac': 'DEV1',
|
||||||
|
'name': 'Unnamed Device',
|
||||||
|
'picture': 'http://example.com/dev1.jpg',
|
||||||
|
'track': True,
|
||||||
|
'vendor': None
|
||||||
|
},
|
||||||
|
'device_2': {
|
||||||
|
'hide_if_away': False,
|
||||||
|
'mac': 'DEV2',
|
||||||
|
'name': 'Unnamed Device',
|
||||||
|
'picture': 'http://example.com/dev2.jpg',
|
||||||
|
'track': True,
|
||||||
|
'vendor': None}
|
||||||
|
}):
|
||||||
self.assertTrue(setup_component(self.hass, device_tracker.DOMAIN, {
|
self.assertTrue(setup_component(self.hass, device_tracker.DOMAIN, {
|
||||||
device_tracker.DOMAIN: {CONF_PLATFORM: 'test'}
|
device_tracker.DOMAIN: {CONF_PLATFORM: 'test'}
|
||||||
}))
|
}))
|
||||||
|
@ -8,10 +8,10 @@ from homeassistant.setup import async_setup_component
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_http_client(loop, hass, test_client):
|
def mock_http_client(hass, test_client):
|
||||||
"""Start the Hass HTTP component."""
|
"""Start the Hass HTTP component."""
|
||||||
loop.run_until_complete(async_setup_component(hass, 'frontend', {}))
|
hass.loop.run_until_complete(async_setup_component(hass, 'frontend', {}))
|
||||||
return loop.run_until_complete(test_client(hass.http.app))
|
return hass.loop.run_until_complete(test_client(hass.http.app))
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
|
Loading…
x
Reference in New Issue
Block a user