diff --git a/homeassistant/components/emulated_hue/__init__.py b/homeassistant/components/emulated_hue/__init__.py index 9fba21b81dc..c89e4fda358 100644 --- a/homeassistant/components/emulated_hue/__init__.py +++ b/homeassistant/components/emulated_hue/__init__.py @@ -14,7 +14,7 @@ from homeassistant.const import ( EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP, ) from homeassistant.components.http import REQUIREMENTS # NOQA -from homeassistant.components.http import HomeAssistantWSGI +from homeassistant.components.http import HomeAssistantHTTP from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers.deprecation import get_deprecated import homeassistant.helpers.config_validation as cv @@ -86,7 +86,7 @@ def setup(hass, yaml_config): """Activate the emulated_hue component.""" config = Config(hass, yaml_config.get(DOMAIN, {})) - server = HomeAssistantWSGI( + server = HomeAssistantHTTP( hass, server_host=config.host_ip_addr, server_port=config.listen_port, diff --git a/homeassistant/components/frontend/__init__.py b/homeassistant/components/frontend/__init__.py index c426a775fc5..7fa1634778d 100644 --- a/homeassistant/components/frontend/__init__.py +++ b/homeassistant/components/frontend/__init__.py @@ -17,7 +17,7 @@ import jinja2 import homeassistant.helpers.config_validation as cv from homeassistant.components.http import HomeAssistantView -from homeassistant.components.http.auth import is_trusted_ip +from homeassistant.components.http.const import KEY_AUTHENTICATED from homeassistant.config import find_config_file, load_yaml_config_file from homeassistant.const import CONF_NAME, EVENT_THEMES_UPDATED from homeassistant.core import callback @@ -490,7 +490,7 @@ class IndexView(HomeAssistantView): panel_url = hass.data[DATA_PANELS][panel].webcomponent_url_es5 no_auth = '1' - if hass.config.api.api_password and not is_trusted_ip(request): + if hass.config.api.api_password and not request[KEY_AUTHENTICATED]: # do not try to auto connect on load no_auth = '0' diff --git a/homeassistant/components/http/__init__.py b/homeassistant/components/http/__init__.py index 22f8c90dfb1..ac253b2821a 100644 --- a/homeassistant/components/http/__init__.py +++ b/homeassistant/components/http/__init__.py @@ -12,35 +12,28 @@ import os import ssl from aiohttp import web -from aiohttp.hdrs import ACCEPT, ORIGIN, CONTENT_TYPE from aiohttp.web_exceptions import HTTPUnauthorized, HTTPMovedPermanently import voluptuous as vol from homeassistant.const import ( - SERVER_PORT, CONTENT_TYPE_JSON, HTTP_HEADER_HA_AUTH, - EVENT_HOMEASSISTANT_STOP, EVENT_HOMEASSISTANT_START, - HTTP_HEADER_X_REQUESTED_WITH) + SERVER_PORT, CONTENT_TYPE_JSON, + EVENT_HOMEASSISTANT_STOP, EVENT_HOMEASSISTANT_START,) from homeassistant.core import is_callback import homeassistant.helpers.config_validation as cv import homeassistant.remote as rem import homeassistant.util as hass_util from homeassistant.util.logging import HideSensitiveDataFilter -from .auth import auth_middleware -from .ban import ban_middleware -from .const import ( - KEY_BANS_ENABLED, KEY_AUTHENTICATED, KEY_LOGIN_THRESHOLD, - KEY_TRUSTED_NETWORKS, KEY_USE_X_FORWARDED_FOR) +from .auth import setup_auth +from .ban import setup_bans +from .cors import setup_cors +from .real_ip import setup_real_ip +from .const import KEY_AUTHENTICATED, KEY_REAL_IP from .static import ( CachingFileResponse, CachingStaticResource, staticresource_middleware) -from .util import get_real_ip REQUIREMENTS = ['aiohttp_cors==0.6.0'] -ALLOWED_CORS_HEADERS = [ - ORIGIN, ACCEPT, HTTP_HEADER_X_REQUESTED_WITH, CONTENT_TYPE, - HTTP_HEADER_HA_AUTH] - DOMAIN = 'http' CONF_API_PASSWORD = 'api_password' @@ -127,7 +120,7 @@ def async_setup(hass, config): logging.getLogger('aiohttp.access').addFilter( HideSensitiveDataFilter(api_password)) - server = HomeAssistantWSGI( + server = HomeAssistantHTTP( hass, server_host=server_host, server_port=server_port, @@ -173,25 +166,29 @@ def async_setup(hass, config): return True -class HomeAssistantWSGI(object): - """WSGI server for Home Assistant.""" +class HomeAssistantHTTP(object): + """HTTP server for Home Assistant.""" def __init__(self, hass, api_password, ssl_certificate, ssl_key, server_host, server_port, cors_origins, use_x_forwarded_for, trusted_networks, login_threshold, is_ban_enabled): - """Initialize the WSGI Home Assistant server.""" - middlewares = [auth_middleware, staticresource_middleware] + """Initialize the HTTP Home Assistant server.""" + app = self.app = web.Application( + middlewares=[staticresource_middleware]) + + # This order matters + setup_real_ip(app, use_x_forwarded_for) if is_ban_enabled: - middlewares.insert(0, ban_middleware) + setup_bans(hass, app, login_threshold) - self.app = web.Application(middlewares=middlewares) - self.app['hass'] = hass - self.app[KEY_USE_X_FORWARDED_FOR] = use_x_forwarded_for - self.app[KEY_TRUSTED_NETWORKS] = trusted_networks - self.app[KEY_BANS_ENABLED] = is_ban_enabled - self.app[KEY_LOGIN_THRESHOLD] = login_threshold + setup_auth(app, trusted_networks, api_password) + + if cors_origins: + setup_cors(app, cors_origins) + + app['hass'] = hass self.hass = hass self.api_password = api_password @@ -199,21 +196,10 @@ class HomeAssistantWSGI(object): self.ssl_key = ssl_key self.server_host = server_host self.server_port = server_port + self.is_ban_enabled = is_ban_enabled self._handler = None self.server = None - if cors_origins: - import aiohttp_cors - - self.cors = aiohttp_cors.setup(self.app, defaults={ - host: aiohttp_cors.ResourceOptions( - allow_headers=ALLOWED_CORS_HEADERS, - allow_methods='*', - ) for host in cors_origins - }) - else: - self.cors = None - def register_view(self, view): """Register a view with the WSGI server. @@ -292,15 +278,7 @@ class HomeAssistantWSGI(object): @asyncio.coroutine def start(self): """Start the WSGI server.""" - cors_added = set() - if self.cors is not None: - for route in list(self.app.router.routes()): - if hasattr(route, 'resource'): - route = route.resource - if route in cors_added: - continue - self.cors.add(route) - cors_added.add(route) + yield from self.app.startup() if self.ssl_certificate: try: @@ -420,7 +398,7 @@ def request_handler_factory(view, handler): raise HTTPUnauthorized() _LOGGER.info('Serving %s to %s (auth: %s)', - request.path, get_real_ip(request), authenticated) + request.path, request.get(KEY_REAL_IP), authenticated) result = handler(request, **request.match_info) diff --git a/homeassistant/components/http/auth.py b/homeassistant/components/http/auth.py index a6a412b6ba2..3128489437a 100644 --- a/homeassistant/components/http/auth.py +++ b/homeassistant/components/http/auth.py @@ -7,55 +7,66 @@ import logging from aiohttp import hdrs from aiohttp.web import middleware +from homeassistant.core import callback from homeassistant.const import HTTP_HEADER_HA_AUTH -from .util import get_real_ip -from .const import KEY_TRUSTED_NETWORKS, KEY_AUTHENTICATED +from .const import KEY_AUTHENTICATED, KEY_REAL_IP DATA_API_PASSWORD = 'api_password' _LOGGER = logging.getLogger(__name__) -@middleware -@asyncio.coroutine -def auth_middleware(request, handler): - """Authenticate as middleware.""" - # If no password set, just always set authenticated=True - if request.app['hass'].http.api_password is None: - request[KEY_AUTHENTICATED] = True +@callback +def setup_auth(app, trusted_networks, api_password): + """Create auth middleware for the app.""" + @middleware + @asyncio.coroutine + def auth_middleware(request, handler): + """Authenticate as middleware.""" + # If no password set, just always set authenticated=True + if api_password is None: + request[KEY_AUTHENTICATED] = True + return (yield from handler(request)) + + # Check authentication + authenticated = False + + if (HTTP_HEADER_HA_AUTH in request.headers and + hmac.compare_digest( + api_password, request.headers[HTTP_HEADER_HA_AUTH])): + # A valid auth header has been set + authenticated = True + + elif (DATA_API_PASSWORD in request.query and + hmac.compare_digest(api_password, + request.query[DATA_API_PASSWORD])): + authenticated = True + + elif (hdrs.AUTHORIZATION in request.headers and + validate_authorization_header(api_password, request)): + authenticated = True + + elif _is_trusted_ip(request, trusted_networks): + authenticated = True + + request[KEY_AUTHENTICATED] = authenticated return (yield from handler(request)) - # Check authentication - authenticated = False + @asyncio.coroutine + def auth_startup(app): + """Initialize auth middleware when app starts up.""" + app.middlewares.append(auth_middleware) - if (HTTP_HEADER_HA_AUTH in request.headers and - validate_password( - request, request.headers[HTTP_HEADER_HA_AUTH])): - # A valid auth header has been set - authenticated = True - - elif (DATA_API_PASSWORD in request.query and - validate_password(request, request.query[DATA_API_PASSWORD])): - authenticated = True - - elif (hdrs.AUTHORIZATION in request.headers and - validate_authorization_header(request)): - authenticated = True - - elif is_trusted_ip(request): - authenticated = True - - request[KEY_AUTHENTICATED] = authenticated - return (yield from handler(request)) + app.on_startup.append(auth_startup) -def is_trusted_ip(request): +def _is_trusted_ip(request, trusted_networks): """Test if request is from a trusted ip.""" - ip_addr = get_real_ip(request) + ip_addr = request[KEY_REAL_IP] - return ip_addr and any( + return any( ip_addr in trusted_network for trusted_network - in request.app[KEY_TRUSTED_NETWORKS]) + in trusted_networks) def validate_password(request, api_password): @@ -64,7 +75,7 @@ def validate_password(request, api_password): api_password, request.app['hass'].http.api_password) -def validate_authorization_header(request): +def validate_authorization_header(api_password, request): """Test an authorization header if valid password.""" if hdrs.AUTHORIZATION not in request.headers: return False @@ -80,4 +91,4 @@ def validate_authorization_header(request): if username != 'homeassistant': return False - return validate_password(request, password) + return hmac.compare_digest(api_password, password) diff --git a/homeassistant/components/http/ban.py b/homeassistant/components/http/ban.py index 8423c53716b..4c797b05b19 100644 --- a/homeassistant/components/http/ban.py +++ b/homeassistant/components/http/ban.py @@ -10,18 +10,20 @@ from aiohttp.web import middleware from aiohttp.web_exceptions import HTTPForbidden, HTTPUnauthorized import voluptuous as vol +from homeassistant.core import callback from homeassistant.components import persistent_notification from homeassistant.config import load_yaml_config_file from homeassistant.exceptions import HomeAssistantError import homeassistant.helpers.config_validation as cv from homeassistant.util.yaml import dump -from .const import ( - KEY_BANS_ENABLED, KEY_BANNED_IPS, KEY_LOGIN_THRESHOLD, - KEY_FAILED_LOGIN_ATTEMPTS) -from .util import get_real_ip +from .const import KEY_REAL_IP _LOGGER = logging.getLogger(__name__) +KEY_BANNED_IPS = 'ha_banned_ips' +KEY_FAILED_LOGIN_ATTEMPTS = 'ha_failed_login_attempts' +KEY_LOGIN_THRESHOLD = 'ha_login_threshold' + NOTIFICATION_ID_BAN = 'ip-ban' NOTIFICATION_ID_LOGIN = 'http-login' @@ -33,21 +35,31 @@ SCHEMA_IP_BAN_ENTRY = vol.Schema({ }) +@callback +def setup_bans(hass, app, login_threshold): + """Create IP Ban middleware for the app.""" + @asyncio.coroutine + def ban_startup(app): + """Initialize bans when app starts up.""" + app.middlewares.append(ban_middleware) + app[KEY_BANNED_IPS] = yield from hass.async_add_job( + load_ip_bans_config, hass.config.path(IP_BANS_FILE)) + app[KEY_FAILED_LOGIN_ATTEMPTS] = defaultdict(int) + app[KEY_LOGIN_THRESHOLD] = login_threshold + + app.on_startup.append(ban_startup) + + @middleware @asyncio.coroutine def ban_middleware(request, handler): """IP Ban middleware.""" - if not request.app[KEY_BANS_ENABLED]: + if KEY_BANNED_IPS not in request.app: + _LOGGER.error('IP Ban middleware loaded but banned IPs not loaded') return (yield from handler(request)) - if KEY_BANNED_IPS not in request.app: - hass = request.app['hass'] - request.app[KEY_BANNED_IPS] = yield from hass.async_add_job( - load_ip_bans_config, hass.config.path(IP_BANS_FILE)) - # Verify if IP is not banned - ip_address_ = get_real_ip(request) - + ip_address_ = request[KEY_REAL_IP] is_banned = any(ip_ban.ip_address == ip_address_ for ip_ban in request.app[KEY_BANNED_IPS]) @@ -64,7 +76,7 @@ def ban_middleware(request, handler): @asyncio.coroutine def process_wrong_login(request): """Process a wrong login attempt.""" - remote_addr = get_real_ip(request) + remote_addr = request[KEY_REAL_IP] msg = ('Login attempt or request with invalid authentication ' 'from {}'.format(remote_addr)) @@ -73,13 +85,11 @@ def process_wrong_login(request): request.app['hass'], msg, 'Login attempt failed', NOTIFICATION_ID_LOGIN) - if (not request.app[KEY_BANS_ENABLED] or + # Check if ban middleware is loaded + if (KEY_BANNED_IPS not in request.app or request.app[KEY_LOGIN_THRESHOLD] < 1): return - if KEY_FAILED_LOGIN_ATTEMPTS not in request.app: - request.app[KEY_FAILED_LOGIN_ATTEMPTS] = defaultdict(int) - request.app[KEY_FAILED_LOGIN_ATTEMPTS][remote_addr] += 1 if (request.app[KEY_FAILED_LOGIN_ATTEMPTS][remote_addr] > diff --git a/homeassistant/components/http/const.py b/homeassistant/components/http/const.py index 4250dd32514..e5494e945c4 100644 --- a/homeassistant/components/http/const.py +++ b/homeassistant/components/http/const.py @@ -1,11 +1,3 @@ """HTTP specific constants.""" KEY_AUTHENTICATED = 'ha_authenticated' -KEY_USE_X_FORWARDED_FOR = 'ha_use_x_forwarded_for' -KEY_TRUSTED_NETWORKS = 'ha_trusted_networks' KEY_REAL_IP = 'ha_real_ip' -KEY_BANS_ENABLED = 'ha_bans_enabled' -KEY_BANNED_IPS = 'ha_banned_ips' -KEY_FAILED_LOGIN_ATTEMPTS = 'ha_failed_login_attempts' -KEY_LOGIN_THRESHOLD = 'ha_login_threshold' - -HTTP_HEADER_X_FORWARDED_FOR = 'X-Forwarded-For' diff --git a/homeassistant/components/http/cors.py b/homeassistant/components/http/cors.py new file mode 100644 index 00000000000..2eb92732d1e --- /dev/null +++ b/homeassistant/components/http/cors.py @@ -0,0 +1,43 @@ +"""Provide cors support for the HTTP component.""" +import asyncio + +from aiohttp.hdrs import ACCEPT, ORIGIN, CONTENT_TYPE + +from homeassistant.const import ( + HTTP_HEADER_X_REQUESTED_WITH, HTTP_HEADER_HA_AUTH) + + +from homeassistant.core import callback + + +ALLOWED_CORS_HEADERS = [ + ORIGIN, ACCEPT, HTTP_HEADER_X_REQUESTED_WITH, CONTENT_TYPE, + HTTP_HEADER_HA_AUTH] + + +@callback +def setup_cors(app, origins): + """Setup cors.""" + import aiohttp_cors + + cors = aiohttp_cors.setup(app, defaults={ + host: aiohttp_cors.ResourceOptions( + allow_headers=ALLOWED_CORS_HEADERS, + allow_methods='*', + ) for host in origins + }) + + @asyncio.coroutine + def cors_startup(app): + """Initialize cors when app starts up.""" + cors_added = set() + + for route in list(app.router.routes()): + if hasattr(route, 'resource'): + route = route.resource + if route in cors_added: + continue + cors.add(route) + cors_added.add(route) + + app.on_startup.append(cors_startup) diff --git a/homeassistant/components/http/real_ip.py b/homeassistant/components/http/real_ip.py new file mode 100644 index 00000000000..1e50f33f69e --- /dev/null +++ b/homeassistant/components/http/real_ip.py @@ -0,0 +1,35 @@ +"""Middleware to fetch real IP.""" +import asyncio +from ipaddress import ip_address + +from aiohttp.web import middleware +from aiohttp.hdrs import X_FORWARDED_FOR + +from homeassistant.core import callback + +from .const import KEY_REAL_IP + + +@callback +def setup_real_ip(app, use_x_forwarded_for): + """Create IP Ban middleware for the app.""" + @middleware + @asyncio.coroutine + def real_ip_middleware(request, handler): + """Real IP middleware.""" + if (use_x_forwarded_for and + X_FORWARDED_FOR in request.headers): + request[KEY_REAL_IP] = ip_address( + request.headers.get(X_FORWARDED_FOR).split(',')[0]) + else: + request[KEY_REAL_IP] = \ + ip_address(request.transport.get_extra_info('peername')[0]) + + return (yield from handler(request)) + + @asyncio.coroutine + def app_startup(app): + """Initialize bans when app starts up.""" + app.middlewares.append(real_ip_middleware) + + app.on_startup.append(app_startup) diff --git a/homeassistant/components/http/util.py b/homeassistant/components/http/util.py deleted file mode 100644 index 359c20f4fa1..00000000000 --- a/homeassistant/components/http/util.py +++ /dev/null @@ -1,25 +0,0 @@ -"""HTTP utilities.""" -from ipaddress import ip_address - -from .const import ( - KEY_REAL_IP, KEY_USE_X_FORWARDED_FOR, HTTP_HEADER_X_FORWARDED_FOR) - - -def get_real_ip(request): - """Get IP address of client.""" - if KEY_REAL_IP in request: - return request[KEY_REAL_IP] - - if (request.app.get(KEY_USE_X_FORWARDED_FOR) and - HTTP_HEADER_X_FORWARDED_FOR in request.headers): - request[KEY_REAL_IP] = ip_address( - request.headers.get(HTTP_HEADER_X_FORWARDED_FOR).split(',')[0]) - else: - peername = request.transport.get_extra_info('peername') - - if peername: - request[KEY_REAL_IP] = ip_address(peername[0]) - else: - request[KEY_REAL_IP] = None - - return request[KEY_REAL_IP] diff --git a/homeassistant/components/telegram_bot/webhooks.py b/homeassistant/components/telegram_bot/webhooks.py index 055f68884a6..5c293459447 100644 --- a/homeassistant/components/telegram_bot/webhooks.py +++ b/homeassistant/components/telegram_bot/webhooks.py @@ -12,7 +12,7 @@ import logging import voluptuous as vol from homeassistant.components.http import HomeAssistantView -from homeassistant.components.http.util import get_real_ip +from homeassistant.components.http.const import KEY_REAL_IP from homeassistant.components.telegram_bot import ( CONF_ALLOWED_CHAT_IDS, BaseTelegramBotEntity, PLATFORM_SCHEMA) from homeassistant.const import ( @@ -110,7 +110,7 @@ class BotPushReceiver(HomeAssistantView, BaseTelegramBotEntity): @asyncio.coroutine def post(self, request): """Accept the POST from telegram.""" - real_ip = get_real_ip(request) + real_ip = request[KEY_REAL_IP] if not any(real_ip in net for net in self.trusted_networks): _LOGGER.warning("Access denied from %s", real_ip) return self.json_message('Access denied', HTTP_UNAUTHORIZED) diff --git a/tests/common.py b/tests/common.py index 9e4575780bc..1b79d15b319 100644 --- a/tests/common.py +++ b/tests/common.py @@ -9,8 +9,6 @@ import logging import threading from contextlib import contextmanager -from aiohttp import web - from homeassistant import core as ha, loader from homeassistant.setup import setup_component, async_setup_component from homeassistant.config import async_process_component_config @@ -25,9 +23,6 @@ from homeassistant.const import ( EVENT_STATE_CHANGED, EVENT_PLATFORM_DISCOVERED, ATTR_SERVICE, ATTR_DISCOVERED, SERVER_PORT, EVENT_HOMEASSISTANT_CLOSE) from homeassistant.components import mqtt, recorder -from homeassistant.components.http.auth import auth_middleware -from homeassistant.components.http.const import ( - KEY_USE_X_FORWARDED_FOR, KEY_BANS_ENABLED, KEY_TRUSTED_NETWORKS) from homeassistant.util.async import ( run_callback_threadsafe, run_coroutine_threadsafe) @@ -262,35 +257,6 @@ def mock_state_change_event(hass, new_state, old_state=None): hass.bus.fire(EVENT_STATE_CHANGED, event_data) -def mock_http_component(hass, api_password=None): - """Mock the HTTP component.""" - hass.http = MagicMock(api_password=api_password) - mock_component(hass, 'http') - hass.http.views = {} - - def mock_register_view(view): - """Store registered view.""" - if isinstance(view, type): - # Instantiate the view, if needed - view = view() - - hass.http.views[view.name] = view - - hass.http.register_view = mock_register_view - - -def mock_http_component_app(hass, api_password=None): - """Create an aiohttp.web.Application instance for testing.""" - if 'http' not in hass.config.components: - mock_http_component(hass, api_password) - app = web.Application(middlewares=[auth_middleware]) - app['hass'] = hass - app[KEY_USE_X_FORWARDED_FOR] = False - app[KEY_BANS_ENABLED] = False - app[KEY_TRUSTED_NETWORKS] = [] - return app - - @asyncio.coroutine def async_mock_mqtt_component(hass, config=None): """Mock the MQTT component.""" diff --git a/tests/components/camera/test_uvc.py b/tests/components/camera/test_uvc.py index ad7ee5f5bcb..40b4fb2d8e2 100644 --- a/tests/components/camera/test_uvc.py +++ b/tests/components/camera/test_uvc.py @@ -9,7 +9,7 @@ from uvcclient import nvr from homeassistant.setup import setup_component from homeassistant.components.camera import uvc -from tests.common import get_test_home_assistant, mock_http_component +from tests.common import get_test_home_assistant class TestUVCSetup(unittest.TestCase): @@ -18,7 +18,6 @@ class TestUVCSetup(unittest.TestCase): def setUp(self): """Setup things to be run when tests are started.""" self.hass = get_test_home_assistant() - mock_http_component(self.hass) def tearDown(self): """Stop everything that was started.""" diff --git a/tests/components/config/test_hassbian.py b/tests/components/config/test_hassbian.py index 659e5ad2448..9038ccc6aa4 100644 --- a/tests/components/config/test_hassbian.py +++ b/tests/components/config/test_hassbian.py @@ -14,7 +14,7 @@ def test_setup_check_env_prevents_load(hass, loop): with patch.dict(os.environ, clear=True), \ patch.object(config, 'SECTIONS', ['hassbian']), \ patch('homeassistant.components.http.' - 'HomeAssistantWSGI.register_view') as reg_view: + 'HomeAssistantHTTP.register_view') as reg_view: loop.run_until_complete(async_setup_component(hass, 'config', {})) assert 'config' in hass.config.components assert reg_view.called is False @@ -25,7 +25,7 @@ def test_setup_check_env_works(hass, loop): with patch.dict(os.environ, {'FORCE_HASSBIAN': '1'}), \ patch.object(config, 'SECTIONS', ['hassbian']), \ patch('homeassistant.components.http.' - 'HomeAssistantWSGI.register_view') as reg_view: + 'HomeAssistantHTTP.register_view') as reg_view: loop.run_until_complete(async_setup_component(hass, 'config', {})) assert 'config' in hass.config.components assert len(reg_view.mock_calls) == 2 diff --git a/tests/components/config/test_init.py b/tests/components/config/test_init.py index 6f69f886419..2d5d814ac8a 100644 --- a/tests/components/config/test_init.py +++ b/tests/components/config/test_init.py @@ -2,19 +2,11 @@ import asyncio from unittest.mock import patch -import pytest - from homeassistant.const import EVENT_COMPONENT_LOADED from homeassistant.setup import async_setup_component, ATTR_COMPONENT from homeassistant.components import config -from tests.common import mock_http_component, mock_coro, mock_component - - -@pytest.fixture(autouse=True) -def stub_http(hass): - """Stub the HTTP component.""" - mock_http_component(hass) +from tests.common import mock_coro, mock_component @asyncio.coroutine diff --git a/tests/components/config/test_zwave.py b/tests/components/config/test_zwave.py index 81800d709e3..c98385a3c32 100644 --- a/tests/components/config/test_zwave.py +++ b/tests/components/config/test_zwave.py @@ -3,28 +3,30 @@ import asyncio import json from unittest.mock import MagicMock, patch +import pytest + from homeassistant.bootstrap import async_setup_component from homeassistant.components import config from homeassistant.components.zwave import DATA_NETWORK, const -from homeassistant.components.config.zwave import ( - ZWaveNodeValueView, ZWaveNodeGroupView, ZWaveNodeConfigView, - ZWaveUserCodeView, ZWaveConfigWriteView) -from tests.common import mock_http_component_app from tests.mock.zwave import MockNode, MockValue, MockEntityValues VIEW_NAME = 'api:config:zwave:device_config' -@asyncio.coroutine -def test_get_device_config(hass, test_client): - """Test getting device config.""" +@pytest.fixture +def client(loop, hass, test_client): + """Client to communicate with Z-Wave config views.""" with patch.object(config, 'SECTIONS', ['zwave']): - yield from async_setup_component(hass, 'config', {}) + loop.run_until_complete(async_setup_component(hass, 'config', {})) - client = yield from test_client(hass.http.app) + return loop.run_until_complete(test_client(hass.http.app)) + +@asyncio.coroutine +def test_get_device_config(client): + """Test getting device config.""" def mock_read(path): """Mock reading data.""" return { @@ -47,13 +49,8 @@ def test_get_device_config(hass, test_client): @asyncio.coroutine -def test_update_device_config(hass, test_client): +def test_update_device_config(client): """Test updating device config.""" - with patch.object(config, 'SECTIONS', ['zwave']): - yield from async_setup_component(hass, 'config', {}) - - client = yield from test_client(hass.http.app) - orig_data = { 'hello.beer': { 'ignored': True, @@ -90,13 +87,8 @@ def test_update_device_config(hass, test_client): @asyncio.coroutine -def test_update_device_config_invalid_key(hass, test_client): +def test_update_device_config_invalid_key(client): """Test updating device config.""" - with patch.object(config, 'SECTIONS', ['zwave']): - yield from async_setup_component(hass, 'config', {}) - - client = yield from test_client(hass.http.app) - resp = yield from client.post( '/api/config/zwave/device_config/invalid_entity', data=json.dumps({ 'polling_intensity': 2 @@ -106,13 +98,8 @@ def test_update_device_config_invalid_key(hass, test_client): @asyncio.coroutine -def test_update_device_config_invalid_data(hass, test_client): +def test_update_device_config_invalid_data(client): """Test updating device config.""" - with patch.object(config, 'SECTIONS', ['zwave']): - yield from async_setup_component(hass, 'config', {}) - - client = yield from test_client(hass.http.app) - resp = yield from client.post( '/api/config/zwave/device_config/hello.beer', data=json.dumps({ 'invalid_option': 2 @@ -122,13 +109,8 @@ def test_update_device_config_invalid_data(hass, test_client): @asyncio.coroutine -def test_update_device_config_invalid_json(hass, test_client): +def test_update_device_config_invalid_json(client): """Test updating device config.""" - with patch.object(config, 'SECTIONS', ['zwave']): - yield from async_setup_component(hass, 'config', {}) - - client = yield from test_client(hass.http.app) - resp = yield from client.post( '/api/config/zwave/device_config/hello.beer', data='not json') @@ -136,11 +118,8 @@ def test_update_device_config_invalid_json(hass, test_client): @asyncio.coroutine -def test_get_values(hass, test_client): +def test_get_values(hass, client): """Test getting values on node.""" - app = mock_http_component_app(hass) - ZWaveNodeValueView().register(app.router) - node = MockNode(node_id=1) value = MockValue(value_id=123456, node=node, label='Test Label', instance=1, index=2, poll_intensity=4) @@ -150,8 +129,6 @@ def test_get_values(hass, test_client): values2 = MockEntityValues(primary=value2) hass.data[const.DATA_ENTITY_VALUES] = [values, values2] - client = yield from test_client(app) - resp = yield from client.get('/api/zwave/values/1') assert resp.status == 200 @@ -168,11 +145,8 @@ def test_get_values(hass, test_client): @asyncio.coroutine -def test_get_groups(hass, test_client): +def test_get_groups(hass, client): """Test getting groupdata on node.""" - app = mock_http_component_app(hass) - ZWaveNodeGroupView().register(app.router) - network = hass.data[DATA_NETWORK] = MagicMock() node = MockNode(node_id=2) node.groups.associations = 'assoc' @@ -182,8 +156,6 @@ def test_get_groups(hass, test_client): node.groups = {1: node.groups} network.nodes = {2: node} - client = yield from test_client(app) - resp = yield from client.get('/api/zwave/groups/2') assert resp.status == 200 @@ -200,18 +172,13 @@ def test_get_groups(hass, test_client): @asyncio.coroutine -def test_get_groups_nogroups(hass, test_client): +def test_get_groups_nogroups(hass, client): """Test getting groupdata on node with no groups.""" - app = mock_http_component_app(hass) - ZWaveNodeGroupView().register(app.router) - network = hass.data[DATA_NETWORK] = MagicMock() node = MockNode(node_id=2) network.nodes = {2: node} - client = yield from test_client(app) - resp = yield from client.get('/api/zwave/groups/2') assert resp.status == 200 @@ -221,16 +188,11 @@ def test_get_groups_nogroups(hass, test_client): @asyncio.coroutine -def test_get_groups_nonode(hass, test_client): +def test_get_groups_nonode(hass, client): """Test getting groupdata on nonexisting node.""" - app = mock_http_component_app(hass) - ZWaveNodeGroupView().register(app.router) - network = hass.data[DATA_NETWORK] = MagicMock() network.nodes = {1: 1, 5: 5} - client = yield from test_client(app) - resp = yield from client.get('/api/zwave/groups/2') assert resp.status == 404 @@ -240,11 +202,8 @@ def test_get_groups_nonode(hass, test_client): @asyncio.coroutine -def test_get_config(hass, test_client): +def test_get_config(hass, client): """Test getting config on node.""" - app = mock_http_component_app(hass) - ZWaveNodeConfigView().register(app.router) - network = hass.data[DATA_NETWORK] = MagicMock() node = MockNode(node_id=2) value = MockValue( @@ -261,8 +220,6 @@ def test_get_config(hass, test_client): network.nodes = {2: node} node.get_values.return_value = node.values - client = yield from test_client(app) - resp = yield from client.get('/api/zwave/config/2') assert resp.status == 200 @@ -278,19 +235,14 @@ def test_get_config(hass, test_client): @asyncio.coroutine -def test_get_config_noconfig_node(hass, test_client): +def test_get_config_noconfig_node(hass, client): """Test getting config on node without config.""" - app = mock_http_component_app(hass) - ZWaveNodeConfigView().register(app.router) - network = hass.data[DATA_NETWORK] = MagicMock() node = MockNode(node_id=2) network.nodes = {2: node} node.get_values.return_value = node.values - client = yield from test_client(app) - resp = yield from client.get('/api/zwave/config/2') assert resp.status == 200 @@ -300,16 +252,11 @@ def test_get_config_noconfig_node(hass, test_client): @asyncio.coroutine -def test_get_config_nonode(hass, test_client): +def test_get_config_nonode(hass, client): """Test getting config on nonexisting node.""" - app = mock_http_component_app(hass) - ZWaveNodeConfigView().register(app.router) - network = hass.data[DATA_NETWORK] = MagicMock() network.nodes = {1: 1, 5: 5} - client = yield from test_client(app) - resp = yield from client.get('/api/zwave/config/2') assert resp.status == 404 @@ -319,16 +266,11 @@ def test_get_config_nonode(hass, test_client): @asyncio.coroutine -def test_get_usercodes_nonode(hass, test_client): +def test_get_usercodes_nonode(hass, client): """Test getting usercodes on nonexisting node.""" - app = mock_http_component_app(hass) - ZWaveUserCodeView().register(app.router) - network = hass.data[DATA_NETWORK] = MagicMock() network.nodes = {1: 1, 5: 5} - client = yield from test_client(app) - resp = yield from client.get('/api/zwave/usercodes/2') assert resp.status == 404 @@ -338,11 +280,8 @@ def test_get_usercodes_nonode(hass, test_client): @asyncio.coroutine -def test_get_usercodes(hass, test_client): +def test_get_usercodes(hass, client): """Test getting usercodes on node.""" - app = mock_http_component_app(hass) - ZWaveUserCodeView().register(app.router) - network = hass.data[DATA_NETWORK] = MagicMock() node = MockNode(node_id=18, command_classes=[const.COMMAND_CLASS_USER_CODE]) @@ -356,8 +295,6 @@ def test_get_usercodes(hass, test_client): network.nodes = {18: node} node.get_values.return_value = node.values - client = yield from test_client(app) - resp = yield from client.get('/api/zwave/usercodes/18') assert resp.status == 200 @@ -369,19 +306,14 @@ def test_get_usercodes(hass, test_client): @asyncio.coroutine -def test_get_usercode_nousercode_node(hass, test_client): +def test_get_usercode_nousercode_node(hass, client): """Test getting usercodes on node without usercodes.""" - app = mock_http_component_app(hass) - ZWaveUserCodeView().register(app.router) - network = hass.data[DATA_NETWORK] = MagicMock() node = MockNode(node_id=18) network.nodes = {18: node} node.get_values.return_value = node.values - client = yield from test_client(app) - resp = yield from client.get('/api/zwave/usercodes/18') assert resp.status == 200 @@ -391,11 +323,8 @@ def test_get_usercode_nousercode_node(hass, test_client): @asyncio.coroutine -def test_get_usercodes_no_genreuser(hass, test_client): +def test_get_usercodes_no_genreuser(hass, client): """Test getting usercodes on node missing genre user.""" - app = mock_http_component_app(hass) - ZWaveUserCodeView().register(app.router) - network = hass.data[DATA_NETWORK] = MagicMock() node = MockNode(node_id=18, command_classes=[const.COMMAND_CLASS_USER_CODE]) @@ -409,8 +338,6 @@ def test_get_usercodes_no_genreuser(hass, test_client): network.nodes = {18: node} node.get_values.return_value = node.values - client = yield from test_client(app) - resp = yield from client.get('/api/zwave/usercodes/18') assert resp.status == 200 @@ -420,13 +347,8 @@ def test_get_usercodes_no_genreuser(hass, test_client): @asyncio.coroutine -def test_save_config_no_network(hass, test_client): +def test_save_config_no_network(hass, client): """Test saving configuration without network data.""" - app = mock_http_component_app(hass) - ZWaveConfigWriteView().register(app.router) - - client = yield from test_client(app) - resp = yield from client.post('/api/zwave/saveconfig') assert resp.status == 404 @@ -435,15 +357,10 @@ def test_save_config_no_network(hass, test_client): @asyncio.coroutine -def test_save_config(hass, test_client): +def test_save_config(hass, client): """Test saving configuration.""" - app = mock_http_component_app(hass) - ZWaveConfigWriteView().register(app.router) - network = hass.data[DATA_NETWORK] = MagicMock() - client = yield from test_client(app) - resp = yield from client.post('/api/zwave/saveconfig') assert resp.status == 200 diff --git a/tests/components/device_tracker/test_automatic.py b/tests/components/device_tracker/test_automatic.py index d40c1518ffa..d90b5c0dd62 100644 --- a/tests/components/device_tracker/test_automatic.py +++ b/tests/components/device_tracker/test_automatic.py @@ -5,11 +5,10 @@ import logging from unittest.mock import patch, MagicMock import aioautomatic +from homeassistant.setup import async_setup_component from homeassistant.components.device_tracker.automatic import ( async_setup_scanner) -from tests.common import mock_http_component - _LOGGER = logging.getLogger(__name__) @@ -23,8 +22,7 @@ def test_invalid_credentials( mock_open, mock_isfile, mock_makedirs, mock_json_dump, mock_json_load, mock_create_session, hass): """Test with invalid credentials.""" - mock_http_component(hass) - + hass.loop.run_until_complete(async_setup_component(hass, 'http', {})) mock_json_load.return_value = {'refresh_token': 'bad_token'} @asyncio.coroutine @@ -59,8 +57,7 @@ def test_valid_credentials( mock_open, mock_isfile, mock_makedirs, mock_json_dump, mock_json_load, mock_ws_connect, mock_create_session, hass): """Test with valid credentials.""" - mock_http_component(hass) - + hass.loop.run_until_complete(async_setup_component(hass, 'http', {})) mock_json_load.return_value = {'refresh_token': 'good_token'} session = MagicMock() diff --git a/tests/components/http/__init__.py b/tests/components/http/__init__.py index 869e80fff75..ef9817a2f1b 100644 --- a/tests/components/http/__init__.py +++ b/tests/components/http/__init__.py @@ -1 +1,38 @@ """Tests for the HTTP component.""" +import asyncio +from ipaddress import ip_address + +from aiohttp import web + +from homeassistant.components.http.const import KEY_REAL_IP + + +def mock_real_ip(app): + """Inject middleware to mock real IP. + + Returns a function to set the real IP. + """ + ip_to_mock = None + + def set_ip_to_mock(value): + nonlocal ip_to_mock + ip_to_mock = value + + @asyncio.coroutine + @web.middleware + def mock_real_ip(request, handler): + """Mock Real IP middleware.""" + nonlocal ip_to_mock + + request[KEY_REAL_IP] = ip_address(ip_to_mock) + + return (yield from handler(request)) + + @asyncio.coroutine + def real_ip_startup(app): + """Startup of real ip.""" + app.middlewares.insert(0, mock_real_ip) + + app.on_startup.append(real_ip_startup) + + return set_ip_to_mock diff --git a/tests/components/http/test_auth.py b/tests/components/http/test_auth.py index ef9c63ad09e..c2687c05a8f 100644 --- a/tests/components/http/test_auth.py +++ b/tests/components/http/test_auth.py @@ -1,195 +1,156 @@ """The tests for the Home Assistant HTTP component.""" # pylint: disable=protected-access import asyncio -from ipaddress import ip_address, ip_network +from ipaddress import ip_network from unittest.mock import patch -import aiohttp +from aiohttp import BasicAuth, web +from aiohttp.web_exceptions import HTTPUnauthorized import pytest -from homeassistant import const +from homeassistant.const import HTTP_HEADER_HA_AUTH from homeassistant.setup import async_setup_component -import homeassistant.components.http as http -from homeassistant.components.http.const import ( - KEY_TRUSTED_NETWORKS, KEY_USE_X_FORWARDED_FOR, HTTP_HEADER_X_FORWARDED_FOR) +from homeassistant.components.http.auth import setup_auth +from homeassistant.components.http.real_ip import setup_real_ip +from homeassistant.components.http.const import KEY_AUTHENTICATED + +from . import mock_real_ip API_PASSWORD = 'test1234' # Don't add 127.0.0.1/::1 as trusted, as it may interfere with other test cases -TRUSTED_NETWORKS = ['192.0.2.0/24', '2001:DB8:ABCD::/48', '100.64.0.1', - 'FD01:DB8::1'] +TRUSTED_NETWORKS = [ + ip_network('192.0.2.0/24'), + ip_network('2001:DB8:ABCD::/48'), + ip_network('100.64.0.1'), + ip_network('FD01:DB8::1'), +] TRUSTED_ADDRESSES = ['100.64.0.1', '192.0.2.100', 'FD01:DB8::1', '2001:DB8:ABCD::1'] UNTRUSTED_ADDRESSES = ['198.51.100.1', '2001:DB8:FA1::1', '127.0.0.1', '::1'] -@pytest.fixture -def mock_api_client(hass, test_client): - """Start the Hass HTTP component.""" - hass.loop.run_until_complete(async_setup_component(hass, 'api', { - 'http': { - http.CONF_API_PASSWORD: API_PASSWORD, - } - })) - return hass.loop.run_until_complete(test_client(hass.http.app)) +@asyncio.coroutine +def mock_handler(request): + """Return if request was authenticated.""" + if not request[KEY_AUTHENTICATED]: + raise HTTPUnauthorized + return web.Response(status=200) @pytest.fixture -def mock_trusted_networks(hass, mock_api_client): - """Mock trusted networks.""" - hass.http.app[KEY_TRUSTED_NETWORKS] = [ - ip_network(trusted_network) - for trusted_network in TRUSTED_NETWORKS] +def app(): + """Fixture to setup a web.Application.""" + app = web.Application() + app.router.add_get('/', mock_handler) + setup_real_ip(app, False) + return app @asyncio.coroutine -def test_access_denied_without_password(mock_api_client): +def test_auth_middleware_loaded_by_default(hass): + """Test accessing to server from banned IP when feature is off.""" + with patch('homeassistant.components.http.setup_auth') as mock_setup: + yield from async_setup_component(hass, 'http', { + 'http': {} + }) + + assert len(mock_setup.mock_calls) == 1 + + +@asyncio.coroutine +def test_access_without_password(app, test_client): """Test access without password.""" - resp = yield from mock_api_client.get(const.URL_API) + setup_auth(app, [], None) + client = yield from test_client(app) + + resp = yield from client.get('/') + assert resp.status == 200 + + +@asyncio.coroutine +def test_access_with_password_in_header(app, test_client): + """Test access with password in URL.""" + setup_auth(app, [], API_PASSWORD) + client = yield from test_client(app) + + req = yield from client.get( + '/', headers={HTTP_HEADER_HA_AUTH: API_PASSWORD}) + assert req.status == 200 + + req = yield from client.get( + '/', headers={HTTP_HEADER_HA_AUTH: 'wrong-pass'}) + assert req.status == 401 + + +@asyncio.coroutine +def test_access_with_password_in_query(app, test_client): + """Test access without password.""" + setup_auth(app, [], API_PASSWORD) + client = yield from test_client(app) + + resp = yield from client.get('/', params={ + 'api_password': API_PASSWORD + }) + assert resp.status == 200 + + resp = yield from client.get('/') assert resp.status == 401 - -@asyncio.coroutine -def test_access_denied_with_wrong_password_in_header(mock_api_client): - """Test access with wrong password.""" - resp = yield from mock_api_client.get(const.URL_API, headers={ - const.HTTP_HEADER_HA_AUTH: 'wrongpassword' + resp = yield from client.get('/', params={ + 'api_password': 'wrong-password' }) assert resp.status == 401 @asyncio.coroutine -def test_access_denied_with_x_forwarded_for(hass, mock_api_client, - mock_trusted_networks): - """Test access denied through the X-Forwarded-For http header.""" - hass.http.use_x_forwarded_for = True - for remote_addr in UNTRUSTED_ADDRESSES: - resp = yield from mock_api_client.get(const.URL_API, headers={ - HTTP_HEADER_X_FORWARDED_FOR: remote_addr}) - - assert resp.status == 401, \ - "{} shouldn't be trusted".format(remote_addr) - - -@asyncio.coroutine -def test_access_denied_with_untrusted_ip(mock_api_client, - mock_trusted_networks): - """Test access with an untrusted ip address.""" - for remote_addr in UNTRUSTED_ADDRESSES: - with patch('homeassistant.components.http.' - 'util.get_real_ip', - return_value=ip_address(remote_addr)): - resp = yield from mock_api_client.get( - const.URL_API, params={'api_password': ''}) - - assert resp.status == 401, \ - "{} shouldn't be trusted".format(remote_addr) - - -@asyncio.coroutine -def test_access_with_password_in_header(mock_api_client, caplog): - """Test access with password in URL.""" - # Hide logging from requests package that we use to test logging - req = yield from mock_api_client.get( - const.URL_API, headers={const.HTTP_HEADER_HA_AUTH: API_PASSWORD}) - - assert req.status == 200 - - logs = caplog.text - - assert const.URL_API in logs - assert API_PASSWORD not in logs - - -@asyncio.coroutine -def test_access_denied_with_wrong_password_in_url(mock_api_client): - """Test access with wrong password.""" - resp = yield from mock_api_client.get( - const.URL_API, params={'api_password': 'wrongpassword'}) - - assert resp.status == 401 - - -@asyncio.coroutine -def test_access_with_password_in_url(mock_api_client, caplog): - """Test access with password in URL.""" - req = yield from mock_api_client.get( - const.URL_API, params={'api_password': API_PASSWORD}) - - assert req.status == 200 - - logs = caplog.text - - assert const.URL_API in logs - assert API_PASSWORD not in logs - - -@asyncio.coroutine -def test_access_granted_with_x_forwarded_for(hass, mock_api_client, caplog, - mock_trusted_networks): - """Test access denied through the X-Forwarded-For http header.""" - hass.http.app[KEY_USE_X_FORWARDED_FOR] = True - for remote_addr in TRUSTED_ADDRESSES: - resp = yield from mock_api_client.get(const.URL_API, headers={ - HTTP_HEADER_X_FORWARDED_FOR: remote_addr}) - - assert resp.status == 200, \ - "{} should be trusted".format(remote_addr) - - -@asyncio.coroutine -def test_access_granted_with_trusted_ip(mock_api_client, caplog, - mock_trusted_networks): - """Test access with trusted addresses.""" - for remote_addr in TRUSTED_ADDRESSES: - with patch('homeassistant.components.http.' - 'auth.get_real_ip', - return_value=ip_address(remote_addr)): - resp = yield from mock_api_client.get( - const.URL_API, params={'api_password': ''}) - - assert resp.status == 200, \ - '{} should be trusted'.format(remote_addr) - - -@asyncio.coroutine -def test_basic_auth_works(mock_api_client, caplog): +def test_basic_auth_works(app, test_client): """Test access with basic authentication.""" - req = yield from mock_api_client.get( - const.URL_API, - auth=aiohttp.BasicAuth('homeassistant', API_PASSWORD)) + setup_auth(app, [], API_PASSWORD) + client = yield from test_client(app) + req = yield from client.get( + '/', + auth=BasicAuth('homeassistant', API_PASSWORD)) assert req.status == 200 - assert const.URL_API in caplog.text - - -@asyncio.coroutine -def test_basic_auth_username_homeassistant(mock_api_client, caplog): - """Test access with basic auth requires username homeassistant.""" - req = yield from mock_api_client.get( - const.URL_API, - auth=aiohttp.BasicAuth('wrong_username', API_PASSWORD)) + req = yield from client.get( + '/', + auth=BasicAuth('wrong_username', API_PASSWORD)) assert req.status == 401 - -@asyncio.coroutine -def test_basic_auth_wrong_password(mock_api_client, caplog): - """Test access with basic auth not allowed with wrong password.""" - req = yield from mock_api_client.get( - const.URL_API, - auth=aiohttp.BasicAuth('homeassistant', 'wrong password')) - + req = yield from client.get( + '/', + auth=BasicAuth('homeassistant', 'wrong password')) assert req.status == 401 - -@asyncio.coroutine -def test_authorization_header_must_be_basic_type(mock_api_client, caplog): - """Test only basic authorization is allowed for auth header.""" - req = yield from mock_api_client.get( - const.URL_API, + req = yield from client.get( + '/', headers={ 'authorization': 'NotBasic abcdefg' }) - assert req.status == 401 + + +@asyncio.coroutine +def test_access_with_trusted_ip(test_client): + """Test access with an untrusted ip address.""" + app = web.Application() + app.router.add_get('/', mock_handler) + + setup_auth(app, TRUSTED_NETWORKS, 'some-pass') + + set_mock_ip = mock_real_ip(app) + client = yield from test_client(app) + + for remote_addr in UNTRUSTED_ADDRESSES: + set_mock_ip(remote_addr) + resp = yield from client.get('/') + assert resp.status == 401, \ + "{} shouldn't be trusted".format(remote_addr) + + for remote_addr in TRUSTED_ADDRESSES: + set_mock_ip(remote_addr) + resp = yield from client.get('/') + assert resp.status == 200, \ + "{} should be trusted".format(remote_addr) diff --git a/tests/components/http/test_ban.py b/tests/components/http/test_ban.py index c9147367c10..bd6df4f4e73 100644 --- a/tests/components/http/test_ban.py +++ b/tests/components/http/test_ban.py @@ -1,91 +1,96 @@ """The tests for the Home Assistant HTTP component.""" # pylint: disable=protected-access import asyncio -from ipaddress import ip_address from unittest.mock import patch, mock_open -import pytest +from aiohttp import web +from aiohttp.web_exceptions import HTTPUnauthorized -from homeassistant import const from homeassistant.setup import async_setup_component import homeassistant.components.http as http -from homeassistant.components.http.const import ( - KEY_BANS_ENABLED, KEY_LOGIN_THRESHOLD, KEY_BANNED_IPS) -from homeassistant.components.http.ban import IpBan, IP_BANS_FILE +from homeassistant.components.http.ban import ( + IpBan, IP_BANS_FILE, setup_bans, KEY_BANNED_IPS) + +from . import mock_real_ip -API_PASSWORD = 'test1234' BANNED_IPS = ['200.201.202.203', '100.64.0.2'] -@pytest.fixture -def mock_api_client(hass, test_client): - """Start the Hass HTTP component.""" - hass.loop.run_until_complete(async_setup_component(hass, 'api', { - 'http': { - http.CONF_API_PASSWORD: API_PASSWORD, - } - })) - hass.http.app[KEY_BANNED_IPS] = [IpBan(banned_ip) for banned_ip - in BANNED_IPS] - return hass.loop.run_until_complete(test_client(hass.http.app)) - - @asyncio.coroutine -def test_access_from_banned_ip(hass, mock_api_client): +def test_access_from_banned_ip(hass, test_client): """Test accessing to server from banned IP. Both trusted and not.""" - hass.http.app[KEY_BANS_ENABLED] = True + app = web.Application() + setup_bans(hass, app, 5) + set_real_ip = mock_real_ip(app) + + with patch('homeassistant.components.http.ban.load_ip_bans_config', + return_value=[IpBan(banned_ip) for banned_ip + in BANNED_IPS]): + client = yield from test_client(app) + for remote_addr in BANNED_IPS: - with patch('homeassistant.components.http.' - 'ban.get_real_ip', - return_value=ip_address(remote_addr)): - resp = yield from mock_api_client.get( - const.URL_API) - assert resp.status == 403 + set_real_ip(remote_addr) + resp = yield from client.get('/') + assert resp.status == 403 @asyncio.coroutine -def test_access_from_banned_ip_when_ban_is_off(hass, mock_api_client): +def test_ban_middleware_not_loaded_by_config(hass): """Test accessing to server from banned IP when feature is off.""" - hass.http.app[KEY_BANS_ENABLED] = False - for remote_addr in BANNED_IPS: - with patch('homeassistant.components.http.' - 'ban.get_real_ip', - return_value=ip_address(remote_addr)): - resp = yield from mock_api_client.get( - const.URL_API, - headers={const.HTTP_HEADER_HA_AUTH: API_PASSWORD}) - assert resp.status == 200 + with patch('homeassistant.components.http.setup_bans') as mock_setup: + yield from async_setup_component(hass, 'http', { + 'http': { + http.CONF_IP_BAN_ENABLED: False, + } + }) + + assert len(mock_setup.mock_calls) == 0 @asyncio.coroutine -def test_ip_bans_file_creation(hass, mock_api_client): +def test_ban_middleware_loaded_by_default(hass): + """Test accessing to server from banned IP when feature is off.""" + with patch('homeassistant.components.http.setup_bans') as mock_setup: + yield from async_setup_component(hass, 'http', { + 'http': {} + }) + + assert len(mock_setup.mock_calls) == 1 + + +@asyncio.coroutine +def test_ip_bans_file_creation(hass, test_client): """Testing if banned IP file created.""" - hass.http.app[KEY_BANS_ENABLED] = True - hass.http.app[KEY_LOGIN_THRESHOLD] = 1 + app = web.Application() + app['hass'] = hass + + @asyncio.coroutine + def unauth_handler(request): + """Return a mock web response.""" + raise HTTPUnauthorized + + app.router.add_get('/', unauth_handler) + setup_bans(hass, app, 1) + mock_real_ip(app)("200.201.202.204") + + with patch('homeassistant.components.http.ban.load_ip_bans_config', + return_value=[IpBan(banned_ip) for banned_ip + in BANNED_IPS]): + client = yield from test_client(app) m = mock_open() - @asyncio.coroutine - def call_server(): - with patch('homeassistant.components.http.' - 'ban.get_real_ip', - return_value=ip_address("200.201.202.204")): - resp = yield from mock_api_client.get( - const.URL_API, - headers={const.HTTP_HEADER_HA_AUTH: 'Wrong password'}) - return resp - with patch('homeassistant.components.http.ban.open', m, create=True): - resp = yield from call_server() + resp = yield from client.get('/') assert resp.status == 401 - assert len(hass.http.app[KEY_BANNED_IPS]) == len(BANNED_IPS) + assert len(app[KEY_BANNED_IPS]) == len(BANNED_IPS) assert m.call_count == 0 - resp = yield from call_server() + resp = yield from client.get('/') assert resp.status == 401 - assert len(hass.http.app[KEY_BANNED_IPS]) == len(BANNED_IPS) + 1 + assert len(app[KEY_BANNED_IPS]) == len(BANNED_IPS) + 1 m.assert_called_once_with(hass.config.path(IP_BANS_FILE), 'a') - resp = yield from call_server() + resp = yield from client.get('/') assert resp.status == 403 assert m.call_count == 1 diff --git a/tests/components/http/test_cors.py b/tests/components/http/test_cors.py new file mode 100644 index 00000000000..22b70e1c0c5 --- /dev/null +++ b/tests/components/http/test_cors.py @@ -0,0 +1,104 @@ +"""Test cors for the HTTP component.""" +import asyncio +from unittest.mock import patch + +from aiohttp import web +from aiohttp.hdrs import ( + ACCESS_CONTROL_ALLOW_ORIGIN, + ACCESS_CONTROL_ALLOW_HEADERS, + ACCESS_CONTROL_REQUEST_HEADERS, + ACCESS_CONTROL_REQUEST_METHOD, + ORIGIN +) +import pytest + +from homeassistant.const import HTTP_HEADER_HA_AUTH +from homeassistant.setup import async_setup_component +from homeassistant.components.http.cors import setup_cors + + +TRUSTED_ORIGIN = 'https://home-assistant.io' + + +@asyncio.coroutine +def test_cors_middleware_not_loaded_by_default(hass): + """Test accessing to server from banned IP when feature is off.""" + with patch('homeassistant.components.http.setup_cors') as mock_setup: + yield from async_setup_component(hass, 'http', { + 'http': {} + }) + + assert len(mock_setup.mock_calls) == 0 + + +@asyncio.coroutine +def test_cors_middleware_loaded_from_config(hass): + """Test accessing to server from banned IP when feature is off.""" + with patch('homeassistant.components.http.setup_cors') as mock_setup: + yield from async_setup_component(hass, 'http', { + 'http': { + 'cors_allowed_origins': ['http://home-assistant.io'] + } + }) + + assert len(mock_setup.mock_calls) == 1 + + +@asyncio.coroutine +def mock_handler(request): + """Return if request was authenticated.""" + return web.Response(status=200) + + +@pytest.fixture +def client(loop, test_client): + """Fixture to setup a web.Application.""" + app = web.Application() + app.router.add_get('/', mock_handler) + setup_cors(app, [TRUSTED_ORIGIN]) + return loop.run_until_complete(test_client(app)) + + +@asyncio.coroutine +def test_cors_requests(client): + """Test cross origin requests.""" + req = yield from client.get('/', headers={ + ORIGIN: TRUSTED_ORIGIN + }) + assert req.status == 200 + assert req.headers[ACCESS_CONTROL_ALLOW_ORIGIN] == \ + TRUSTED_ORIGIN + + # With password in URL + req = yield from client.get('/', params={ + 'api_password': 'some-pass' + }, headers={ + ORIGIN: TRUSTED_ORIGIN + }) + assert req.status == 200 + assert req.headers[ACCESS_CONTROL_ALLOW_ORIGIN] == \ + TRUSTED_ORIGIN + + # With password in headers + req = yield from client.get('/', headers={ + HTTP_HEADER_HA_AUTH: 'some-pass', + ORIGIN: TRUSTED_ORIGIN + }) + assert req.status == 200 + assert req.headers[ACCESS_CONTROL_ALLOW_ORIGIN] == \ + TRUSTED_ORIGIN + + +@asyncio.coroutine +def test_cors_preflight_allowed(client): + """Test cross origin resource sharing preflight (OPTIONS) request.""" + req = yield from client.options('/', headers={ + ORIGIN: TRUSTED_ORIGIN, + ACCESS_CONTROL_REQUEST_METHOD: 'GET', + ACCESS_CONTROL_REQUEST_HEADERS: 'x-ha-access' + }) + + assert req.status == 200 + assert req.headers[ACCESS_CONTROL_ALLOW_ORIGIN] == TRUSTED_ORIGIN + assert req.headers[ACCESS_CONTROL_ALLOW_HEADERS] == \ + HTTP_HEADER_HA_AUTH.upper() diff --git a/tests/components/http/test_init.py b/tests/components/http/test_init.py index 4ff87efd137..ab06b48043e 100644 --- a/tests/components/http/test_init.py +++ b/tests/components/http/test_init.py @@ -1,124 +1,10 @@ """The tests for the Home Assistant HTTP component.""" import asyncio -from aiohttp.hdrs import ( - ORIGIN, ACCESS_CONTROL_ALLOW_ORIGIN, ACCESS_CONTROL_ALLOW_HEADERS, - ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS, - CONTENT_TYPE) -import requests -from tests.common import get_test_instance_port, get_test_home_assistant +from homeassistant.setup import async_setup_component -from homeassistant import const, setup import homeassistant.components.http as http -API_PASSWORD = 'test1234' -SERVER_PORT = get_test_instance_port() -HTTP_BASE = '127.0.0.1:{}'.format(SERVER_PORT) -HTTP_BASE_URL = 'http://{}'.format(HTTP_BASE) -HA_HEADERS = { - const.HTTP_HEADER_HA_AUTH: API_PASSWORD, - CONTENT_TYPE: const.CONTENT_TYPE_JSON, -} -CORS_ORIGINS = [HTTP_BASE_URL, HTTP_BASE] - -hass = None - - -def _url(path=''): - """Helper method to generate URLs.""" - return HTTP_BASE_URL + path - - -# pylint: disable=invalid-name -def setUpModule(): - """Initialize a Home Assistant server.""" - global hass - - hass = get_test_home_assistant() - - setup.setup_component( - hass, http.DOMAIN, { - http.DOMAIN: { - http.CONF_API_PASSWORD: API_PASSWORD, - http.CONF_SERVER_PORT: SERVER_PORT, - http.CONF_CORS_ORIGINS: CORS_ORIGINS, - } - } - ) - - setup.setup_component(hass, 'api') - - # Registering static path as it caused CORS to blow up - hass.http.register_static_path( - '/custom_components', hass.config.path('custom_components')) - - hass.start() - - -# pylint: disable=invalid-name -def tearDownModule(): - """Stop the Home Assistant server.""" - hass.stop() - - -class TestCors: - """Test HTTP component.""" - - def test_cors_allowed_with_password_in_url(self): - """Test cross origin resource sharing with password in url.""" - req = requests.get(_url(const.URL_API), - params={'api_password': API_PASSWORD}, - headers={ORIGIN: HTTP_BASE_URL}) - - allow_origin = ACCESS_CONTROL_ALLOW_ORIGIN - - assert req.status_code == 200 - assert req.headers.get(allow_origin) == HTTP_BASE_URL - - def test_cors_allowed_with_password_in_header(self): - """Test cross origin resource sharing with password in header.""" - headers = { - const.HTTP_HEADER_HA_AUTH: API_PASSWORD, - ORIGIN: HTTP_BASE_URL - } - req = requests.get(_url(const.URL_API), headers=headers) - - allow_origin = ACCESS_CONTROL_ALLOW_ORIGIN - - assert req.status_code == 200 - assert req.headers.get(allow_origin) == HTTP_BASE_URL - - def test_cors_denied_without_origin_header(self): - """Test cross origin resource sharing with password in header.""" - headers = { - const.HTTP_HEADER_HA_AUTH: API_PASSWORD - } - req = requests.get(_url(const.URL_API), headers=headers) - - allow_origin = ACCESS_CONTROL_ALLOW_ORIGIN - allow_headers = ACCESS_CONTROL_ALLOW_HEADERS - - assert req.status_code == 200 - assert allow_origin not in req.headers - assert allow_headers not in req.headers - - def test_cors_preflight_allowed(self): - """Test cross origin resource sharing preflight (OPTIONS) request.""" - headers = { - ORIGIN: HTTP_BASE_URL, - ACCESS_CONTROL_REQUEST_METHOD: 'GET', - ACCESS_CONTROL_REQUEST_HEADERS: 'x-ha-access' - } - req = requests.options(_url(const.URL_API), headers=headers) - - allow_origin = ACCESS_CONTROL_ALLOW_ORIGIN - allow_headers = ACCESS_CONTROL_ALLOW_HEADERS - - assert req.status_code == 200 - assert req.headers.get(allow_origin) == HTTP_BASE_URL - assert req.headers.get(allow_headers) == \ - const.HTTP_HEADER_HA_AUTH.upper() - class TestView(http.HomeAssistantView): """Test the HTTP views.""" @@ -133,12 +19,12 @@ class TestView(http.HomeAssistantView): @asyncio.coroutine -def test_registering_view_while_running(hass, test_client): +def test_registering_view_while_running(hass, test_client, unused_port): """Test that we can register a view while the server is running.""" - yield from setup.async_setup_component( + yield from async_setup_component( hass, http.DOMAIN, { http.DOMAIN: { - http.CONF_SERVER_PORT: get_test_instance_port(), + http.CONF_SERVER_PORT: unused_port(), } } ) @@ -151,7 +37,7 @@ def test_registering_view_while_running(hass, test_client): @asyncio.coroutine def test_api_base_url_with_domain(hass): """Test setting API URL.""" - result = yield from setup.async_setup_component(hass, 'http', { + result = yield from async_setup_component(hass, 'http', { 'http': { 'base_url': 'example.com' } @@ -163,7 +49,7 @@ def test_api_base_url_with_domain(hass): @asyncio.coroutine def test_api_base_url_with_ip(hass): """Test setting api url.""" - result = yield from setup.async_setup_component(hass, 'http', { + result = yield from async_setup_component(hass, 'http', { 'http': { 'server_host': '1.1.1.1' } @@ -175,7 +61,7 @@ def test_api_base_url_with_ip(hass): @asyncio.coroutine def test_api_base_url_with_ip_port(hass): """Test setting api url.""" - result = yield from setup.async_setup_component(hass, 'http', { + result = yield from async_setup_component(hass, 'http', { 'http': { 'base_url': '1.1.1.1:8124' } @@ -187,9 +73,34 @@ def test_api_base_url_with_ip_port(hass): @asyncio.coroutine def test_api_no_base_url(hass): """Test setting api url.""" - result = yield from setup.async_setup_component(hass, 'http', { + result = yield from async_setup_component(hass, 'http', { 'http': { } }) assert result assert hass.config.api.base_url == 'http://127.0.0.1:8123' + + +@asyncio.coroutine +def test_not_log_password(hass, unused_port, test_client, caplog): + """Test access with password doesn't get logged.""" + result = yield from async_setup_component(hass, 'api', { + 'http': { + http.CONF_SERVER_PORT: unused_port(), + http.CONF_API_PASSWORD: 'some-pass' + } + }) + assert result + + client = yield from test_client(hass.http.app) + + resp = yield from client.get('/api/', params={ + 'api_password': 'some-pass' + }) + + assert resp.status == 200 + logs = caplog.text + + # Ensure we don't log API passwords + assert '/api/' in logs + assert 'some-pass' not in logs diff --git a/tests/components/http/test_real_ip.py b/tests/components/http/test_real_ip.py new file mode 100644 index 00000000000..90201ab4c10 --- /dev/null +++ b/tests/components/http/test_real_ip.py @@ -0,0 +1,48 @@ +"""Test real IP middleware.""" +import asyncio + +from aiohttp import web +from aiohttp.hdrs import X_FORWARDED_FOR + +from homeassistant.components.http.real_ip import setup_real_ip +from homeassistant.components.http.const import KEY_REAL_IP + + +@asyncio.coroutine +def mock_handler(request): + """Handler that returns the real IP as text.""" + return web.Response(text=str(request[KEY_REAL_IP])) + + +@asyncio.coroutine +def test_ignore_x_forwarded_for(test_client): + """Test that we get the IP from the transport.""" + app = web.Application() + app.router.add_get('/', mock_handler) + setup_real_ip(app, False) + + mock_api_client = yield from test_client(app) + + resp = yield from mock_api_client.get('/', headers={ + X_FORWARDED_FOR: '255.255.255.255' + }) + assert resp.status == 200 + text = yield from resp.text() + assert text != '255.255.255.255' + + +@asyncio.coroutine +def test_use_x_forwarded_for(test_client): + """Test that we get the IP from the transport.""" + app = web.Application() + app.router.add_get('/', mock_handler) + setup_real_ip(app, True) + + mock_api_client = yield from test_client(app) + + resp = yield from mock_api_client.get('/', headers={ + X_FORWARDED_FOR: '255.255.255.255' + }) + assert resp.status == 200 + text = yield from resp.text() + assert text == '255.255.255.255' diff --git a/tests/components/mqtt/test_server.py b/tests/components/mqtt/test_server.py index 7ce9ec00797..9b4c0c69ac6 100644 --- a/tests/components/mqtt/test_server.py +++ b/tests/components/mqtt/test_server.py @@ -4,8 +4,7 @@ from unittest.mock import Mock, MagicMock, patch from homeassistant.setup import setup_component import homeassistant.components.mqtt as mqtt -from tests.common import ( - get_test_home_assistant, mock_coro, mock_http_component) +from tests.common import get_test_home_assistant, mock_coro class TestMQTT: @@ -14,7 +13,9 @@ class TestMQTT: def setup_method(self, method): """Setup things to be run when tests are started.""" self.hass = get_test_home_assistant() - mock_http_component(self.hass, 'super_secret') + setup_component(self.hass, 'http', { + 'api_password': 'super_secret' + }) def teardown_method(self, method): """Stop everything that was started.""" diff --git a/tests/components/notify/test_html5.py b/tests/components/notify/test_html5.py index 6fb2e6454de..d6c06f77d93 100644 --- a/tests/components/notify/test_html5.py +++ b/tests/components/notify/test_html5.py @@ -4,12 +4,10 @@ import json from unittest.mock import patch, MagicMock, mock_open from aiohttp.hdrs import AUTHORIZATION +from homeassistant.setup import async_setup_component from homeassistant.exceptions import HomeAssistantError -from homeassistant.util.json import save_json from homeassistant.components.notify import html5 -from tests.common import mock_http_component_app - CONFIG_FILE = 'file.conf' SUBSCRIPTION_1 = { @@ -52,6 +50,23 @@ REGISTER_URL = '/api/notify.html5' PUBLISH_URL = '/api/notify.html5/callback' +@asyncio.coroutine +def mock_client(hass, test_client, registrations=None): + """Create a test client for HTML5 views.""" + if registrations is None: + registrations = {} + + with patch('homeassistant.components.notify.html5._load_config', + return_value=registrations): + yield from async_setup_component(hass, 'notify', { + 'notify': { + 'platform': 'html5' + } + }) + + return (yield from test_client(hass.http.app)) + + class TestHtml5Notify(object): """Tests for HTML5 notify platform.""" @@ -89,8 +104,6 @@ class TestHtml5Notify(object): service.send_message('Hello', target=['device', 'non_existing'], data={'icon': 'beer.png'}) - print(mock_wp.mock_calls) - assert len(mock_wp.mock_calls) == 3 # WebPusher constructor @@ -104,421 +117,224 @@ class TestHtml5Notify(object): assert payload['body'] == 'Hello' assert payload['icon'] == 'beer.png' - @asyncio.coroutine - def test_registering_new_device_view(self, loop, test_client): - """Test that the HTML view works.""" - hass = MagicMock() - expected = { - 'unnamed device': SUBSCRIPTION_1, - } - hass.config.path.return_value = CONFIG_FILE - service = html5.get_service(hass, {}) +@asyncio.coroutine +def test_registering_new_device_view(hass, test_client): + """Test that the HTML view works.""" + client = yield from mock_client(hass, test_client) - assert service is not None - - assert len(hass.mock_calls) == 3 - - view = hass.mock_calls[1][1][0] - assert view.json_path == hass.config.path.return_value - assert view.registrations == {} - - hass.loop = loop - app = mock_http_component_app(hass) - view.register(app.router) - client = yield from test_client(app) - hass.http.is_banned_ip.return_value = False + with patch('homeassistant.components.notify.html5.save_json') as mock_save: resp = yield from client.post(REGISTER_URL, data=json.dumps(SUBSCRIPTION_1)) - content = yield from resp.text() - assert resp.status == 200, content - assert view.registrations == expected + assert resp.status == 200 + assert len(mock_save.mock_calls) == 1 + assert mock_save.mock_calls[0][1][1] == { + 'unnamed device': SUBSCRIPTION_1, + } - hass.async_add_job.assert_called_with(save_json, CONFIG_FILE, expected) - @asyncio.coroutine - def test_registering_new_device_expiration_view(self, loop, test_client): - """Test that the HTML view works.""" - hass = MagicMock() - expected = { - 'unnamed device': SUBSCRIPTION_4, - } +@asyncio.coroutine +def test_registering_new_device_expiration_view(hass, test_client): + """Test that the HTML view works.""" + client = yield from mock_client(hass, test_client) - hass.config.path.return_value = CONFIG_FILE - service = html5.get_service(hass, {}) - - assert service is not None - - # assert hass.called - assert len(hass.mock_calls) == 3 - - view = hass.mock_calls[1][1][0] - assert view.json_path == hass.config.path.return_value - assert view.registrations == {} - - hass.loop = loop - app = mock_http_component_app(hass) - view.register(app.router) - client = yield from test_client(app) - hass.http.is_banned_ip.return_value = False + with patch('homeassistant.components.notify.html5.save_json') as mock_save: resp = yield from client.post(REGISTER_URL, data=json.dumps(SUBSCRIPTION_4)) - content = yield from resp.text() - assert resp.status == 200, content - assert view.registrations == expected + assert resp.status == 200 + assert mock_save.mock_calls[0][1][1] == { + 'unnamed device': SUBSCRIPTION_4, + } - hass.async_add_job.assert_called_with(save_json, CONFIG_FILE, expected) - @asyncio.coroutine - def test_registering_new_device_fails_view(self, loop, test_client): - """Test subs. are not altered when registering a new device fails.""" - hass = MagicMock() - expected = {} - - hass.config.path.return_value = CONFIG_FILE - html5.get_service(hass, {}) - view = hass.mock_calls[1][1][0] - - hass.loop = loop - app = mock_http_component_app(hass) - view.register(app.router) - client = yield from test_client(app) - hass.http.is_banned_ip.return_value = False - - hass.async_add_job.side_effect = HomeAssistantError() +@asyncio.coroutine +def test_registering_new_device_fails_view(hass, test_client): + """Test subs. are not altered when registering a new device fails.""" + registrations = {} + client = yield from mock_client(hass, test_client, registrations) + with patch('homeassistant.components.notify.html5.save_json', + side_effect=HomeAssistantError()): resp = yield from client.post(REGISTER_URL, - data=json.dumps(SUBSCRIPTION_1)) + data=json.dumps(SUBSCRIPTION_4)) - content = yield from resp.text() - assert resp.status == 500, content - assert view.registrations == expected + assert resp.status == 500 + assert registrations == {} - @asyncio.coroutine - def test_registering_existing_device_view(self, loop, test_client): - """Test subscription is updated when registering existing device.""" - hass = MagicMock() - expected = { - 'unnamed device': SUBSCRIPTION_4, - } - hass.config.path.return_value = CONFIG_FILE - html5.get_service(hass, {}) - view = hass.mock_calls[1][1][0] - - hass.loop = loop - app = mock_http_component_app(hass) - view.register(app.router) - client = yield from test_client(app) - hass.http.is_banned_ip.return_value = False +@asyncio.coroutine +def test_registering_existing_device_view(hass, test_client): + """Test subscription is updated when registering existing device.""" + registrations = {} + client = yield from mock_client(hass, test_client, registrations) + with patch('homeassistant.components.notify.html5.save_json') as mock_save: yield from client.post(REGISTER_URL, data=json.dumps(SUBSCRIPTION_1)) resp = yield from client.post(REGISTER_URL, data=json.dumps(SUBSCRIPTION_4)) - content = yield from resp.text() - assert resp.status == 200, content - assert view.registrations == expected + assert resp.status == 200 + assert mock_save.mock_calls[0][1][1] == { + 'unnamed device': SUBSCRIPTION_4, + } + assert registrations == { + 'unnamed device': SUBSCRIPTION_4, + } - hass.async_add_job.assert_called_with(save_json, CONFIG_FILE, expected) - @asyncio.coroutine - def test_registering_existing_device_fails_view(self, loop, test_client): - """Test sub. is not updated when registering existing device fails.""" - hass = MagicMock() - expected = { - 'unnamed device': SUBSCRIPTION_1, - } - - hass.config.path.return_value = CONFIG_FILE - html5.get_service(hass, {}) - view = hass.mock_calls[1][1][0] - - hass.loop = loop - app = mock_http_component_app(hass) - view.register(app.router) - client = yield from test_client(app) - hass.http.is_banned_ip.return_value = False +@asyncio.coroutine +def test_registering_existing_device_fails_view(hass, test_client): + """Test sub. is not updated when registering existing device fails.""" + registrations = {} + client = yield from mock_client(hass, test_client, registrations) + with patch('homeassistant.components.notify.html5.save_json') as mock_save: yield from client.post(REGISTER_URL, data=json.dumps(SUBSCRIPTION_1)) - - hass.async_add_job.side_effect = HomeAssistantError() + mock_save.side_effect = HomeAssistantError resp = yield from client.post(REGISTER_URL, data=json.dumps(SUBSCRIPTION_4)) - content = yield from resp.text() - assert resp.status == 500, content - assert view.registrations == expected + assert resp.status == 500 + assert registrations == { + 'unnamed device': SUBSCRIPTION_1, + } - @asyncio.coroutine - def test_registering_new_device_validation(self, loop, test_client): - """Test various errors when registering a new device.""" - hass = MagicMock() - m = mock_open() - with patch( - 'homeassistant.util.json.open', - m, create=True - ): - hass.config.path.return_value = CONFIG_FILE - service = html5.get_service(hass, {}) +@asyncio.coroutine +def test_registering_new_device_validation(hass, test_client): + """Test various errors when registering a new device.""" + client = yield from mock_client(hass, test_client) - assert service is not None + resp = yield from client.post(REGISTER_URL, data=json.dumps({ + 'browser': 'invalid browser', + 'subscription': 'sub info', + })) + assert resp.status == 400 - # assert hass.called - assert len(hass.mock_calls) == 3 + resp = yield from client.post(REGISTER_URL, data=json.dumps({ + 'browser': 'chrome', + })) + assert resp.status == 400 - view = hass.mock_calls[1][1][0] + with patch('homeassistant.components.notify.html5.save_json', + return_value=False): + resp = yield from client.post(REGISTER_URL, data=json.dumps({ + 'browser': 'chrome', + 'subscription': 'sub info', + })) + assert resp.status == 400 - hass.loop = loop - app = mock_http_component_app(hass) - view.register(app.router) - client = yield from test_client(app) - hass.http.is_banned_ip.return_value = False - resp = yield from client.post(REGISTER_URL, data=json.dumps({ - 'browser': 'invalid browser', - 'subscription': 'sub info', - })) - assert resp.status == 400 +@asyncio.coroutine +def test_unregistering_device_view(hass, test_client): + """Test that the HTML unregister view works.""" + registrations = { + 'some device': SUBSCRIPTION_1, + 'other device': SUBSCRIPTION_2, + } + client = yield from mock_client(hass, test_client, registrations) - resp = yield from client.post(REGISTER_URL, data=json.dumps({ - 'browser': 'chrome', - })) - assert resp.status == 400 + with patch('homeassistant.components.notify.html5.save_json') as mock_save: + resp = yield from client.delete(REGISTER_URL, data=json.dumps({ + 'subscription': SUBSCRIPTION_1['subscription'], + })) - with patch('homeassistant.components.notify.html5.save_json', - return_value=False): - # resp = view.post(Request(builder.get_environ())) - resp = yield from client.post(REGISTER_URL, data=json.dumps({ - 'browser': 'chrome', - 'subscription': 'sub info', - })) + assert resp.status == 200 + assert len(mock_save.mock_calls) == 1 + assert registrations == { + 'other device': SUBSCRIPTION_2 + } - assert resp.status == 400 - @asyncio.coroutine - def test_unregistering_device_view(self, loop, test_client): - """Test that the HTML unregister view works.""" - hass = MagicMock() +@asyncio.coroutine +def test_unregister_device_view_handle_unknown_subscription(hass, test_client): + """Test that the HTML unregister view handles unknown subscriptions.""" + registrations = {} + client = yield from mock_client(hass, test_client, registrations) - config = { - 'some device': SUBSCRIPTION_1, - 'other device': SUBSCRIPTION_2, - } + with patch('homeassistant.components.notify.html5.save_json') as mock_save: + resp = yield from client.delete(REGISTER_URL, data=json.dumps({ + 'subscription': SUBSCRIPTION_3['subscription'] + })) - m = mock_open(read_data=json.dumps(config)) - with patch( - 'homeassistant.util.json.open', - m, create=True - ): - hass.config.path.return_value = CONFIG_FILE - service = html5.get_service(hass, {}) + assert resp.status == 200, resp.response + assert registrations == {} + assert len(mock_save.mock_calls) == 0 - assert service is not None - # assert hass.called - assert len(hass.mock_calls) == 3 +@asyncio.coroutine +def test_unregistering_device_view_handles_save_error(hass, test_client): + """Test that the HTML unregister view handles save errors.""" + registrations = { + 'some device': SUBSCRIPTION_1, + 'other device': SUBSCRIPTION_2, + } + client = yield from mock_client(hass, test_client, registrations) - view = hass.mock_calls[1][1][0] - assert view.json_path == hass.config.path.return_value - assert view.registrations == config + with patch('homeassistant.components.notify.html5.save_json', + side_effect=HomeAssistantError()): + resp = yield from client.delete(REGISTER_URL, data=json.dumps({ + 'subscription': SUBSCRIPTION_1['subscription'], + })) - hass.loop = loop - app = mock_http_component_app(hass) - view.register(app.router) - client = yield from test_client(app) - hass.http.is_banned_ip.return_value = False + assert resp.status == 500, resp.response + assert registrations == { + 'some device': SUBSCRIPTION_1, + 'other device': SUBSCRIPTION_2, + } - resp = yield from client.delete(REGISTER_URL, data=json.dumps({ - 'subscription': SUBSCRIPTION_1['subscription'], - })) - config.pop('some device') +@asyncio.coroutine +def test_callback_view_no_jwt(hass, test_client): + """Test that the notification callback view works without JWT.""" + client = yield from mock_client(hass, test_client) + resp = yield from client.post(PUBLISH_URL, data=json.dumps({ + 'type': 'push', + 'tag': '3bc28d69-0921-41f1-ac6a-7a627ba0aa72' + })) - assert resp.status == 200, resp.response - assert view.registrations == config + assert resp.status == 401, resp.response - hass.async_add_job.assert_called_with(save_json, CONFIG_FILE, - config) - @asyncio.coroutine - def test_unregister_device_view_handle_unknown_subscription( - self, loop, test_client): - """Test that the HTML unregister view handles unknown subscriptions.""" - hass = MagicMock() +@asyncio.coroutine +def test_callback_view_with_jwt(hass, test_client): + """Test that the notification callback view works with JWT.""" + registrations = { + 'device': SUBSCRIPTION_1 + } + client = yield from mock_client(hass, test_client, registrations) - config = { - 'some device': SUBSCRIPTION_1, - 'other device': SUBSCRIPTION_2, - } + with patch('pywebpush.WebPusher') as mock_wp: + yield from hass.services.async_call('notify', 'notify', { + 'message': 'Hello', + 'target': ['device'], + 'data': {'icon': 'beer.png'} + }, blocking=True) - m = mock_open(read_data=json.dumps(config)) - with patch( - 'homeassistant.util.json.open', - m, create=True - ): - hass.config.path.return_value = CONFIG_FILE - service = html5.get_service(hass, {}) + assert len(mock_wp.mock_calls) == 3 - assert service is not None + # WebPusher constructor + assert mock_wp.mock_calls[0][1][0] == \ + SUBSCRIPTION_1['subscription'] + # Third mock_call checks the status_code of the response. + assert mock_wp.mock_calls[2][0] == '().send().status_code.__eq__' - # assert hass.called - assert len(hass.mock_calls) == 3 + # Call to send + push_payload = json.loads(mock_wp.mock_calls[1][1][0]) - view = hass.mock_calls[1][1][0] - assert view.json_path == hass.config.path.return_value - assert view.registrations == config + assert push_payload['body'] == 'Hello' + assert push_payload['icon'] == 'beer.png' - hass.loop = loop - app = mock_http_component_app(hass) - view.register(app.router) - client = yield from test_client(app) - hass.http.is_banned_ip.return_value = False + bearer_token = "Bearer {}".format(push_payload['data']['jwt']) - resp = yield from client.delete(REGISTER_URL, data=json.dumps({ - 'subscription': SUBSCRIPTION_3['subscription'] - })) + resp = yield from client.post(PUBLISH_URL, json={ + 'type': 'push', + }, headers={AUTHORIZATION: bearer_token}) - assert resp.status == 200, resp.response - assert view.registrations == config - - hass.async_add_job.assert_not_called() - - @asyncio.coroutine - def test_unregistering_device_view_handles_save_error( - self, loop, test_client): - """Test that the HTML unregister view handles save errors.""" - hass = MagicMock() - - config = { - 'some device': SUBSCRIPTION_1, - 'other device': SUBSCRIPTION_2, - } - - m = mock_open(read_data=json.dumps(config)) - with patch( - 'homeassistant.util.json.open', - m, create=True - ): - hass.config.path.return_value = CONFIG_FILE - service = html5.get_service(hass, {}) - - assert service is not None - - # assert hass.called - assert len(hass.mock_calls) == 3 - - view = hass.mock_calls[1][1][0] - assert view.json_path == hass.config.path.return_value - assert view.registrations == config - - hass.loop = loop - app = mock_http_component_app(hass) - view.register(app.router) - client = yield from test_client(app) - hass.http.is_banned_ip.return_value = False - - hass.async_add_job.side_effect = HomeAssistantError() - resp = yield from client.delete(REGISTER_URL, data=json.dumps({ - 'subscription': SUBSCRIPTION_1['subscription'], - })) - - assert resp.status == 500, resp.response - assert view.registrations == config - - @asyncio.coroutine - def test_callback_view_no_jwt(self, loop, test_client): - """Test that the notification callback view works without JWT.""" - hass = MagicMock() - - m = mock_open() - with patch( - 'homeassistant.util.json.open', - m, create=True - ): - hass.config.path.return_value = CONFIG_FILE - service = html5.get_service(hass, {}) - - assert service is not None - - # assert hass.called - assert len(hass.mock_calls) == 3 - - view = hass.mock_calls[2][1][0] - - hass.loop = loop - app = mock_http_component_app(hass) - view.register(app.router) - client = yield from test_client(app) - hass.http.is_banned_ip.return_value = False - - resp = yield from client.post(PUBLISH_URL, data=json.dumps({ - 'type': 'push', - 'tag': '3bc28d69-0921-41f1-ac6a-7a627ba0aa72' - })) - - assert resp.status == 401, resp.response - - @asyncio.coroutine - def test_callback_view_with_jwt(self, loop, test_client): - """Test that the notification callback view works with JWT.""" - hass = MagicMock() - - data = { - 'device': SUBSCRIPTION_1 - } - - m = mock_open(read_data=json.dumps(data)) - with patch( - 'homeassistant.util.json.open', - m, create=True - ): - hass.config.path.return_value = CONFIG_FILE - service = html5.get_service(hass, {'gcm_sender_id': '100'}) - - assert service is not None - - # assert hass.called - assert len(hass.mock_calls) == 3 - - with patch('pywebpush.WebPusher') as mock_wp: - service.send_message( - 'Hello', target=['device'], data={'icon': 'beer.png'}) - - assert len(mock_wp.mock_calls) == 3 - - # WebPusher constructor - assert mock_wp.mock_calls[0][1][0] == \ - SUBSCRIPTION_1['subscription'] - # Third mock_call checks the status_code of the response. - assert mock_wp.mock_calls[2][0] == '().send().status_code.__eq__' - - # Call to send - push_payload = json.loads(mock_wp.mock_calls[1][1][0]) - - assert push_payload['body'] == 'Hello' - assert push_payload['icon'] == 'beer.png' - - view = hass.mock_calls[2][1][0] - view.registrations = data - - bearer_token = "Bearer {}".format(push_payload['data']['jwt']) - - hass.loop = loop - app = mock_http_component_app(hass) - view.register(app.router) - client = yield from test_client(app) - hass.http.is_banned_ip.return_value = False - - resp = yield from client.post(PUBLISH_URL, data=json.dumps({ - 'type': 'push', - }), headers={AUTHORIZATION: bearer_token}) - - assert resp.status == 200 - body = yield from resp.json() - assert body == {"event": "push", "status": "ok"} + assert resp.status == 200 + body = yield from resp.json() + assert body == {"event": "push", "status": "ok"} diff --git a/tests/components/test_history.py b/tests/components/test_history.py index 8484e2c536f..0c6995cc1ad 100644 --- a/tests/components/test_history.py +++ b/tests/components/test_history.py @@ -10,8 +10,7 @@ import homeassistant.util.dt as dt_util from homeassistant.components import history, recorder from tests.common import ( - init_recorder_component, mock_http_component, mock_state_change_event, - get_test_home_assistant) + init_recorder_component, mock_state_change_event, get_test_home_assistant) class TestComponentHistory(unittest.TestCase): @@ -38,7 +37,6 @@ class TestComponentHistory(unittest.TestCase): def test_setup(self): """Test setup method of history.""" - mock_http_component(self.hass) config = history.CONFIG_SCHEMA({ # ha.DOMAIN: {}, history.DOMAIN: { diff --git a/tests/components/test_logbook.py b/tests/components/test_logbook.py index 6a79994586c..472590ae13d 100644 --- a/tests/components/test_logbook.py +++ b/tests/components/test_logbook.py @@ -14,7 +14,7 @@ from homeassistant.components import logbook from homeassistant.setup import setup_component from tests.common import ( - mock_http_component, init_recorder_component, get_test_home_assistant) + init_recorder_component, get_test_home_assistant) _LOGGER = logging.getLogger(__name__) @@ -29,10 +29,7 @@ class TestComponentLogbook(unittest.TestCase): """Setup things to be run when tests are started.""" self.hass = get_test_home_assistant() init_recorder_component(self.hass) # Force an in memory DB - mock_http_component(self.hass) - self.hass.config.components |= set(['frontend', 'recorder', 'api']) - assert setup_component(self.hass, logbook.DOMAIN, - self.EMPTY_CONFIG) + assert setup_component(self.hass, logbook.DOMAIN, self.EMPTY_CONFIG) self.hass.start() def tearDown(self): diff --git a/tests/components/test_shopping_list.py b/tests/components/test_shopping_list.py index 2e1a03c37d0..4203f7587ae 100644 --- a/tests/components/test_shopping_list.py +++ b/tests/components/test_shopping_list.py @@ -150,7 +150,6 @@ def test_api_update_fails(hass, test_client): assert resp.status == 404 beer_id = hass.data['shopping_list'].items[0]['id'] - client = yield from test_client(hass.http.app) resp = yield from client.post( '/api/shopping_list/item/{}'.format(beer_id), json={ 'name': 123, diff --git a/tests/components/test_websocket_api.py b/tests/components/test_websocket_api.py index 8b6c7494214..f85030a6892 100644 --- a/tests/components/test_websocket_api.py +++ b/tests/components/test_websocket_api.py @@ -8,8 +8,9 @@ import pytest from homeassistant.core import callback from homeassistant.components import websocket_api as wapi, frontend +from homeassistant.setup import async_setup_component -from tests.common import mock_http_component_app, mock_coro +from tests.common import mock_coro API_PASSWORD = 'test1234' @@ -17,10 +18,10 @@ API_PASSWORD = 'test1234' @pytest.fixture def websocket_client(loop, hass, test_client): """Websocket client fixture connected to websocket server.""" - websocket_app = mock_http_component_app(hass) - wapi.WebsocketAPIView().register(websocket_app.router) + assert loop.run_until_complete( + async_setup_component(hass, 'websocket_api')) - client = loop.run_until_complete(test_client(websocket_app)) + client = loop.run_until_complete(test_client(hass.http.app)) ws = loop.run_until_complete(client.ws_connect(wapi.URL)) auth_ok = loop.run_until_complete(ws.receive_json()) @@ -35,10 +36,14 @@ def websocket_client(loop, hass, test_client): @pytest.fixture def no_auth_websocket_client(hass, loop, test_client): """Websocket connection that requires authentication.""" - websocket_app = mock_http_component_app(hass, API_PASSWORD) - wapi.WebsocketAPIView().register(websocket_app.router) + assert loop.run_until_complete( + async_setup_component(hass, 'websocket_api', { + 'http': { + 'api_password': API_PASSWORD + } + })) - client = loop.run_until_complete(test_client(websocket_app)) + client = loop.run_until_complete(test_client(hass.http.app)) ws = loop.run_until_complete(client.ws_connect(wapi.URL)) auth_ok = loop.run_until_complete(ws.receive_json())