mirror of
https://github.com/home-assistant/core.git
synced 2025-04-23 08:47:57 +00:00
Cleanup http (#12424)
* Clean up HTTP component * Clean up HTTP mock * Remove unused import * Fix test * Lint
This commit is contained in:
parent
ad8fe8a93a
commit
f32911d036
@ -14,7 +14,7 @@ from homeassistant.const import (
|
||||
EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP,
|
||||
)
|
||||
from homeassistant.components.http import REQUIREMENTS # NOQA
|
||||
from homeassistant.components.http import HomeAssistantWSGI
|
||||
from homeassistant.components.http import HomeAssistantHTTP
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers.deprecation import get_deprecated
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
@ -86,7 +86,7 @@ def setup(hass, yaml_config):
|
||||
"""Activate the emulated_hue component."""
|
||||
config = Config(hass, yaml_config.get(DOMAIN, {}))
|
||||
|
||||
server = HomeAssistantWSGI(
|
||||
server = HomeAssistantHTTP(
|
||||
hass,
|
||||
server_host=config.host_ip_addr,
|
||||
server_port=config.listen_port,
|
||||
|
@ -17,7 +17,7 @@ import jinja2
|
||||
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.components.http import HomeAssistantView
|
||||
from homeassistant.components.http.auth import is_trusted_ip
|
||||
from homeassistant.components.http.const import KEY_AUTHENTICATED
|
||||
from homeassistant.config import find_config_file, load_yaml_config_file
|
||||
from homeassistant.const import CONF_NAME, EVENT_THEMES_UPDATED
|
||||
from homeassistant.core import callback
|
||||
@ -490,7 +490,7 @@ class IndexView(HomeAssistantView):
|
||||
panel_url = hass.data[DATA_PANELS][panel].webcomponent_url_es5
|
||||
|
||||
no_auth = '1'
|
||||
if hass.config.api.api_password and not is_trusted_ip(request):
|
||||
if hass.config.api.api_password and not request[KEY_AUTHENTICATED]:
|
||||
# do not try to auto connect on load
|
||||
no_auth = '0'
|
||||
|
||||
|
@ -12,35 +12,28 @@ import os
|
||||
import ssl
|
||||
|
||||
from aiohttp import web
|
||||
from aiohttp.hdrs import ACCEPT, ORIGIN, CONTENT_TYPE
|
||||
from aiohttp.web_exceptions import HTTPUnauthorized, HTTPMovedPermanently
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.const import (
|
||||
SERVER_PORT, CONTENT_TYPE_JSON, HTTP_HEADER_HA_AUTH,
|
||||
EVENT_HOMEASSISTANT_STOP, EVENT_HOMEASSISTANT_START,
|
||||
HTTP_HEADER_X_REQUESTED_WITH)
|
||||
SERVER_PORT, CONTENT_TYPE_JSON,
|
||||
EVENT_HOMEASSISTANT_STOP, EVENT_HOMEASSISTANT_START,)
|
||||
from homeassistant.core import is_callback
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
import homeassistant.remote as rem
|
||||
import homeassistant.util as hass_util
|
||||
from homeassistant.util.logging import HideSensitiveDataFilter
|
||||
|
||||
from .auth import auth_middleware
|
||||
from .ban import ban_middleware
|
||||
from .const import (
|
||||
KEY_BANS_ENABLED, KEY_AUTHENTICATED, KEY_LOGIN_THRESHOLD,
|
||||
KEY_TRUSTED_NETWORKS, KEY_USE_X_FORWARDED_FOR)
|
||||
from .auth import setup_auth
|
||||
from .ban import setup_bans
|
||||
from .cors import setup_cors
|
||||
from .real_ip import setup_real_ip
|
||||
from .const import KEY_AUTHENTICATED, KEY_REAL_IP
|
||||
from .static import (
|
||||
CachingFileResponse, CachingStaticResource, staticresource_middleware)
|
||||
from .util import get_real_ip
|
||||
|
||||
REQUIREMENTS = ['aiohttp_cors==0.6.0']
|
||||
|
||||
ALLOWED_CORS_HEADERS = [
|
||||
ORIGIN, ACCEPT, HTTP_HEADER_X_REQUESTED_WITH, CONTENT_TYPE,
|
||||
HTTP_HEADER_HA_AUTH]
|
||||
|
||||
DOMAIN = 'http'
|
||||
|
||||
CONF_API_PASSWORD = 'api_password'
|
||||
@ -127,7 +120,7 @@ def async_setup(hass, config):
|
||||
logging.getLogger('aiohttp.access').addFilter(
|
||||
HideSensitiveDataFilter(api_password))
|
||||
|
||||
server = HomeAssistantWSGI(
|
||||
server = HomeAssistantHTTP(
|
||||
hass,
|
||||
server_host=server_host,
|
||||
server_port=server_port,
|
||||
@ -173,25 +166,29 @@ def async_setup(hass, config):
|
||||
return True
|
||||
|
||||
|
||||
class HomeAssistantWSGI(object):
|
||||
"""WSGI server for Home Assistant."""
|
||||
class HomeAssistantHTTP(object):
|
||||
"""HTTP server for Home Assistant."""
|
||||
|
||||
def __init__(self, hass, api_password, ssl_certificate,
|
||||
ssl_key, server_host, server_port, cors_origins,
|
||||
use_x_forwarded_for, trusted_networks,
|
||||
login_threshold, is_ban_enabled):
|
||||
"""Initialize the WSGI Home Assistant server."""
|
||||
middlewares = [auth_middleware, staticresource_middleware]
|
||||
"""Initialize the HTTP Home Assistant server."""
|
||||
app = self.app = web.Application(
|
||||
middlewares=[staticresource_middleware])
|
||||
|
||||
# This order matters
|
||||
setup_real_ip(app, use_x_forwarded_for)
|
||||
|
||||
if is_ban_enabled:
|
||||
middlewares.insert(0, ban_middleware)
|
||||
setup_bans(hass, app, login_threshold)
|
||||
|
||||
self.app = web.Application(middlewares=middlewares)
|
||||
self.app['hass'] = hass
|
||||
self.app[KEY_USE_X_FORWARDED_FOR] = use_x_forwarded_for
|
||||
self.app[KEY_TRUSTED_NETWORKS] = trusted_networks
|
||||
self.app[KEY_BANS_ENABLED] = is_ban_enabled
|
||||
self.app[KEY_LOGIN_THRESHOLD] = login_threshold
|
||||
setup_auth(app, trusted_networks, api_password)
|
||||
|
||||
if cors_origins:
|
||||
setup_cors(app, cors_origins)
|
||||
|
||||
app['hass'] = hass
|
||||
|
||||
self.hass = hass
|
||||
self.api_password = api_password
|
||||
@ -199,21 +196,10 @@ class HomeAssistantWSGI(object):
|
||||
self.ssl_key = ssl_key
|
||||
self.server_host = server_host
|
||||
self.server_port = server_port
|
||||
self.is_ban_enabled = is_ban_enabled
|
||||
self._handler = None
|
||||
self.server = None
|
||||
|
||||
if cors_origins:
|
||||
import aiohttp_cors
|
||||
|
||||
self.cors = aiohttp_cors.setup(self.app, defaults={
|
||||
host: aiohttp_cors.ResourceOptions(
|
||||
allow_headers=ALLOWED_CORS_HEADERS,
|
||||
allow_methods='*',
|
||||
) for host in cors_origins
|
||||
})
|
||||
else:
|
||||
self.cors = None
|
||||
|
||||
def register_view(self, view):
|
||||
"""Register a view with the WSGI server.
|
||||
|
||||
@ -292,15 +278,7 @@ class HomeAssistantWSGI(object):
|
||||
@asyncio.coroutine
|
||||
def start(self):
|
||||
"""Start the WSGI server."""
|
||||
cors_added = set()
|
||||
if self.cors is not None:
|
||||
for route in list(self.app.router.routes()):
|
||||
if hasattr(route, 'resource'):
|
||||
route = route.resource
|
||||
if route in cors_added:
|
||||
continue
|
||||
self.cors.add(route)
|
||||
cors_added.add(route)
|
||||
yield from self.app.startup()
|
||||
|
||||
if self.ssl_certificate:
|
||||
try:
|
||||
@ -420,7 +398,7 @@ def request_handler_factory(view, handler):
|
||||
raise HTTPUnauthorized()
|
||||
|
||||
_LOGGER.info('Serving %s to %s (auth: %s)',
|
||||
request.path, get_real_ip(request), authenticated)
|
||||
request.path, request.get(KEY_REAL_IP), authenticated)
|
||||
|
||||
result = handler(request, **request.match_info)
|
||||
|
||||
|
@ -7,55 +7,66 @@ import logging
|
||||
from aiohttp import hdrs
|
||||
from aiohttp.web import middleware
|
||||
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.const import HTTP_HEADER_HA_AUTH
|
||||
from .util import get_real_ip
|
||||
from .const import KEY_TRUSTED_NETWORKS, KEY_AUTHENTICATED
|
||||
from .const import KEY_AUTHENTICATED, KEY_REAL_IP
|
||||
|
||||
DATA_API_PASSWORD = 'api_password'
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@middleware
|
||||
@asyncio.coroutine
|
||||
def auth_middleware(request, handler):
|
||||
"""Authenticate as middleware."""
|
||||
# If no password set, just always set authenticated=True
|
||||
if request.app['hass'].http.api_password is None:
|
||||
request[KEY_AUTHENTICATED] = True
|
||||
@callback
|
||||
def setup_auth(app, trusted_networks, api_password):
|
||||
"""Create auth middleware for the app."""
|
||||
@middleware
|
||||
@asyncio.coroutine
|
||||
def auth_middleware(request, handler):
|
||||
"""Authenticate as middleware."""
|
||||
# If no password set, just always set authenticated=True
|
||||
if api_password is None:
|
||||
request[KEY_AUTHENTICATED] = True
|
||||
return (yield from handler(request))
|
||||
|
||||
# Check authentication
|
||||
authenticated = False
|
||||
|
||||
if (HTTP_HEADER_HA_AUTH in request.headers and
|
||||
hmac.compare_digest(
|
||||
api_password, request.headers[HTTP_HEADER_HA_AUTH])):
|
||||
# A valid auth header has been set
|
||||
authenticated = True
|
||||
|
||||
elif (DATA_API_PASSWORD in request.query and
|
||||
hmac.compare_digest(api_password,
|
||||
request.query[DATA_API_PASSWORD])):
|
||||
authenticated = True
|
||||
|
||||
elif (hdrs.AUTHORIZATION in request.headers and
|
||||
validate_authorization_header(api_password, request)):
|
||||
authenticated = True
|
||||
|
||||
elif _is_trusted_ip(request, trusted_networks):
|
||||
authenticated = True
|
||||
|
||||
request[KEY_AUTHENTICATED] = authenticated
|
||||
return (yield from handler(request))
|
||||
|
||||
# Check authentication
|
||||
authenticated = False
|
||||
@asyncio.coroutine
|
||||
def auth_startup(app):
|
||||
"""Initialize auth middleware when app starts up."""
|
||||
app.middlewares.append(auth_middleware)
|
||||
|
||||
if (HTTP_HEADER_HA_AUTH in request.headers and
|
||||
validate_password(
|
||||
request, request.headers[HTTP_HEADER_HA_AUTH])):
|
||||
# A valid auth header has been set
|
||||
authenticated = True
|
||||
|
||||
elif (DATA_API_PASSWORD in request.query and
|
||||
validate_password(request, request.query[DATA_API_PASSWORD])):
|
||||
authenticated = True
|
||||
|
||||
elif (hdrs.AUTHORIZATION in request.headers and
|
||||
validate_authorization_header(request)):
|
||||
authenticated = True
|
||||
|
||||
elif is_trusted_ip(request):
|
||||
authenticated = True
|
||||
|
||||
request[KEY_AUTHENTICATED] = authenticated
|
||||
return (yield from handler(request))
|
||||
app.on_startup.append(auth_startup)
|
||||
|
||||
|
||||
def is_trusted_ip(request):
|
||||
def _is_trusted_ip(request, trusted_networks):
|
||||
"""Test if request is from a trusted ip."""
|
||||
ip_addr = get_real_ip(request)
|
||||
ip_addr = request[KEY_REAL_IP]
|
||||
|
||||
return ip_addr and any(
|
||||
return any(
|
||||
ip_addr in trusted_network for trusted_network
|
||||
in request.app[KEY_TRUSTED_NETWORKS])
|
||||
in trusted_networks)
|
||||
|
||||
|
||||
def validate_password(request, api_password):
|
||||
@ -64,7 +75,7 @@ def validate_password(request, api_password):
|
||||
api_password, request.app['hass'].http.api_password)
|
||||
|
||||
|
||||
def validate_authorization_header(request):
|
||||
def validate_authorization_header(api_password, request):
|
||||
"""Test an authorization header if valid password."""
|
||||
if hdrs.AUTHORIZATION not in request.headers:
|
||||
return False
|
||||
@ -80,4 +91,4 @@ def validate_authorization_header(request):
|
||||
if username != 'homeassistant':
|
||||
return False
|
||||
|
||||
return validate_password(request, password)
|
||||
return hmac.compare_digest(api_password, password)
|
||||
|
@ -10,18 +10,20 @@ from aiohttp.web import middleware
|
||||
from aiohttp.web_exceptions import HTTPForbidden, HTTPUnauthorized
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.components import persistent_notification
|
||||
from homeassistant.config import load_yaml_config_file
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.util.yaml import dump
|
||||
from .const import (
|
||||
KEY_BANS_ENABLED, KEY_BANNED_IPS, KEY_LOGIN_THRESHOLD,
|
||||
KEY_FAILED_LOGIN_ATTEMPTS)
|
||||
from .util import get_real_ip
|
||||
from .const import KEY_REAL_IP
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
KEY_BANNED_IPS = 'ha_banned_ips'
|
||||
KEY_FAILED_LOGIN_ATTEMPTS = 'ha_failed_login_attempts'
|
||||
KEY_LOGIN_THRESHOLD = 'ha_login_threshold'
|
||||
|
||||
NOTIFICATION_ID_BAN = 'ip-ban'
|
||||
NOTIFICATION_ID_LOGIN = 'http-login'
|
||||
|
||||
@ -33,21 +35,31 @@ SCHEMA_IP_BAN_ENTRY = vol.Schema({
|
||||
})
|
||||
|
||||
|
||||
@callback
|
||||
def setup_bans(hass, app, login_threshold):
|
||||
"""Create IP Ban middleware for the app."""
|
||||
@asyncio.coroutine
|
||||
def ban_startup(app):
|
||||
"""Initialize bans when app starts up."""
|
||||
app.middlewares.append(ban_middleware)
|
||||
app[KEY_BANNED_IPS] = yield from hass.async_add_job(
|
||||
load_ip_bans_config, hass.config.path(IP_BANS_FILE))
|
||||
app[KEY_FAILED_LOGIN_ATTEMPTS] = defaultdict(int)
|
||||
app[KEY_LOGIN_THRESHOLD] = login_threshold
|
||||
|
||||
app.on_startup.append(ban_startup)
|
||||
|
||||
|
||||
@middleware
|
||||
@asyncio.coroutine
|
||||
def ban_middleware(request, handler):
|
||||
"""IP Ban middleware."""
|
||||
if not request.app[KEY_BANS_ENABLED]:
|
||||
if KEY_BANNED_IPS not in request.app:
|
||||
_LOGGER.error('IP Ban middleware loaded but banned IPs not loaded')
|
||||
return (yield from handler(request))
|
||||
|
||||
if KEY_BANNED_IPS not in request.app:
|
||||
hass = request.app['hass']
|
||||
request.app[KEY_BANNED_IPS] = yield from hass.async_add_job(
|
||||
load_ip_bans_config, hass.config.path(IP_BANS_FILE))
|
||||
|
||||
# Verify if IP is not banned
|
||||
ip_address_ = get_real_ip(request)
|
||||
|
||||
ip_address_ = request[KEY_REAL_IP]
|
||||
is_banned = any(ip_ban.ip_address == ip_address_
|
||||
for ip_ban in request.app[KEY_BANNED_IPS])
|
||||
|
||||
@ -64,7 +76,7 @@ def ban_middleware(request, handler):
|
||||
@asyncio.coroutine
|
||||
def process_wrong_login(request):
|
||||
"""Process a wrong login attempt."""
|
||||
remote_addr = get_real_ip(request)
|
||||
remote_addr = request[KEY_REAL_IP]
|
||||
|
||||
msg = ('Login attempt or request with invalid authentication '
|
||||
'from {}'.format(remote_addr))
|
||||
@ -73,13 +85,11 @@ def process_wrong_login(request):
|
||||
request.app['hass'], msg, 'Login attempt failed',
|
||||
NOTIFICATION_ID_LOGIN)
|
||||
|
||||
if (not request.app[KEY_BANS_ENABLED] or
|
||||
# Check if ban middleware is loaded
|
||||
if (KEY_BANNED_IPS not in request.app or
|
||||
request.app[KEY_LOGIN_THRESHOLD] < 1):
|
||||
return
|
||||
|
||||
if KEY_FAILED_LOGIN_ATTEMPTS not in request.app:
|
||||
request.app[KEY_FAILED_LOGIN_ATTEMPTS] = defaultdict(int)
|
||||
|
||||
request.app[KEY_FAILED_LOGIN_ATTEMPTS][remote_addr] += 1
|
||||
|
||||
if (request.app[KEY_FAILED_LOGIN_ATTEMPTS][remote_addr] >
|
||||
|
@ -1,11 +1,3 @@
|
||||
"""HTTP specific constants."""
|
||||
KEY_AUTHENTICATED = 'ha_authenticated'
|
||||
KEY_USE_X_FORWARDED_FOR = 'ha_use_x_forwarded_for'
|
||||
KEY_TRUSTED_NETWORKS = 'ha_trusted_networks'
|
||||
KEY_REAL_IP = 'ha_real_ip'
|
||||
KEY_BANS_ENABLED = 'ha_bans_enabled'
|
||||
KEY_BANNED_IPS = 'ha_banned_ips'
|
||||
KEY_FAILED_LOGIN_ATTEMPTS = 'ha_failed_login_attempts'
|
||||
KEY_LOGIN_THRESHOLD = 'ha_login_threshold'
|
||||
|
||||
HTTP_HEADER_X_FORWARDED_FOR = 'X-Forwarded-For'
|
||||
|
43
homeassistant/components/http/cors.py
Normal file
43
homeassistant/components/http/cors.py
Normal file
@ -0,0 +1,43 @@
|
||||
"""Provide cors support for the HTTP component."""
|
||||
import asyncio
|
||||
|
||||
from aiohttp.hdrs import ACCEPT, ORIGIN, CONTENT_TYPE
|
||||
|
||||
from homeassistant.const import (
|
||||
HTTP_HEADER_X_REQUESTED_WITH, HTTP_HEADER_HA_AUTH)
|
||||
|
||||
|
||||
from homeassistant.core import callback
|
||||
|
||||
|
||||
ALLOWED_CORS_HEADERS = [
|
||||
ORIGIN, ACCEPT, HTTP_HEADER_X_REQUESTED_WITH, CONTENT_TYPE,
|
||||
HTTP_HEADER_HA_AUTH]
|
||||
|
||||
|
||||
@callback
|
||||
def setup_cors(app, origins):
|
||||
"""Setup cors."""
|
||||
import aiohttp_cors
|
||||
|
||||
cors = aiohttp_cors.setup(app, defaults={
|
||||
host: aiohttp_cors.ResourceOptions(
|
||||
allow_headers=ALLOWED_CORS_HEADERS,
|
||||
allow_methods='*',
|
||||
) for host in origins
|
||||
})
|
||||
|
||||
@asyncio.coroutine
|
||||
def cors_startup(app):
|
||||
"""Initialize cors when app starts up."""
|
||||
cors_added = set()
|
||||
|
||||
for route in list(app.router.routes()):
|
||||
if hasattr(route, 'resource'):
|
||||
route = route.resource
|
||||
if route in cors_added:
|
||||
continue
|
||||
cors.add(route)
|
||||
cors_added.add(route)
|
||||
|
||||
app.on_startup.append(cors_startup)
|
35
homeassistant/components/http/real_ip.py
Normal file
35
homeassistant/components/http/real_ip.py
Normal file
@ -0,0 +1,35 @@
|
||||
"""Middleware to fetch real IP."""
|
||||
import asyncio
|
||||
from ipaddress import ip_address
|
||||
|
||||
from aiohttp.web import middleware
|
||||
from aiohttp.hdrs import X_FORWARDED_FOR
|
||||
|
||||
from homeassistant.core import callback
|
||||
|
||||
from .const import KEY_REAL_IP
|
||||
|
||||
|
||||
@callback
|
||||
def setup_real_ip(app, use_x_forwarded_for):
|
||||
"""Create IP Ban middleware for the app."""
|
||||
@middleware
|
||||
@asyncio.coroutine
|
||||
def real_ip_middleware(request, handler):
|
||||
"""Real IP middleware."""
|
||||
if (use_x_forwarded_for and
|
||||
X_FORWARDED_FOR in request.headers):
|
||||
request[KEY_REAL_IP] = ip_address(
|
||||
request.headers.get(X_FORWARDED_FOR).split(',')[0])
|
||||
else:
|
||||
request[KEY_REAL_IP] = \
|
||||
ip_address(request.transport.get_extra_info('peername')[0])
|
||||
|
||||
return (yield from handler(request))
|
||||
|
||||
@asyncio.coroutine
|
||||
def app_startup(app):
|
||||
"""Initialize bans when app starts up."""
|
||||
app.middlewares.append(real_ip_middleware)
|
||||
|
||||
app.on_startup.append(app_startup)
|
@ -1,25 +0,0 @@
|
||||
"""HTTP utilities."""
|
||||
from ipaddress import ip_address
|
||||
|
||||
from .const import (
|
||||
KEY_REAL_IP, KEY_USE_X_FORWARDED_FOR, HTTP_HEADER_X_FORWARDED_FOR)
|
||||
|
||||
|
||||
def get_real_ip(request):
|
||||
"""Get IP address of client."""
|
||||
if KEY_REAL_IP in request:
|
||||
return request[KEY_REAL_IP]
|
||||
|
||||
if (request.app.get(KEY_USE_X_FORWARDED_FOR) and
|
||||
HTTP_HEADER_X_FORWARDED_FOR in request.headers):
|
||||
request[KEY_REAL_IP] = ip_address(
|
||||
request.headers.get(HTTP_HEADER_X_FORWARDED_FOR).split(',')[0])
|
||||
else:
|
||||
peername = request.transport.get_extra_info('peername')
|
||||
|
||||
if peername:
|
||||
request[KEY_REAL_IP] = ip_address(peername[0])
|
||||
else:
|
||||
request[KEY_REAL_IP] = None
|
||||
|
||||
return request[KEY_REAL_IP]
|
@ -12,7 +12,7 @@ import logging
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components.http import HomeAssistantView
|
||||
from homeassistant.components.http.util import get_real_ip
|
||||
from homeassistant.components.http.const import KEY_REAL_IP
|
||||
from homeassistant.components.telegram_bot import (
|
||||
CONF_ALLOWED_CHAT_IDS, BaseTelegramBotEntity, PLATFORM_SCHEMA)
|
||||
from homeassistant.const import (
|
||||
@ -110,7 +110,7 @@ class BotPushReceiver(HomeAssistantView, BaseTelegramBotEntity):
|
||||
@asyncio.coroutine
|
||||
def post(self, request):
|
||||
"""Accept the POST from telegram."""
|
||||
real_ip = get_real_ip(request)
|
||||
real_ip = request[KEY_REAL_IP]
|
||||
if not any(real_ip in net for net in self.trusted_networks):
|
||||
_LOGGER.warning("Access denied from %s", real_ip)
|
||||
return self.json_message('Access denied', HTTP_UNAUTHORIZED)
|
||||
|
@ -9,8 +9,6 @@ import logging
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from homeassistant import core as ha, loader
|
||||
from homeassistant.setup import setup_component, async_setup_component
|
||||
from homeassistant.config import async_process_component_config
|
||||
@ -25,9 +23,6 @@ from homeassistant.const import (
|
||||
EVENT_STATE_CHANGED, EVENT_PLATFORM_DISCOVERED, ATTR_SERVICE,
|
||||
ATTR_DISCOVERED, SERVER_PORT, EVENT_HOMEASSISTANT_CLOSE)
|
||||
from homeassistant.components import mqtt, recorder
|
||||
from homeassistant.components.http.auth import auth_middleware
|
||||
from homeassistant.components.http.const import (
|
||||
KEY_USE_X_FORWARDED_FOR, KEY_BANS_ENABLED, KEY_TRUSTED_NETWORKS)
|
||||
from homeassistant.util.async import (
|
||||
run_callback_threadsafe, run_coroutine_threadsafe)
|
||||
|
||||
@ -262,35 +257,6 @@ def mock_state_change_event(hass, new_state, old_state=None):
|
||||
hass.bus.fire(EVENT_STATE_CHANGED, event_data)
|
||||
|
||||
|
||||
def mock_http_component(hass, api_password=None):
|
||||
"""Mock the HTTP component."""
|
||||
hass.http = MagicMock(api_password=api_password)
|
||||
mock_component(hass, 'http')
|
||||
hass.http.views = {}
|
||||
|
||||
def mock_register_view(view):
|
||||
"""Store registered view."""
|
||||
if isinstance(view, type):
|
||||
# Instantiate the view, if needed
|
||||
view = view()
|
||||
|
||||
hass.http.views[view.name] = view
|
||||
|
||||
hass.http.register_view = mock_register_view
|
||||
|
||||
|
||||
def mock_http_component_app(hass, api_password=None):
|
||||
"""Create an aiohttp.web.Application instance for testing."""
|
||||
if 'http' not in hass.config.components:
|
||||
mock_http_component(hass, api_password)
|
||||
app = web.Application(middlewares=[auth_middleware])
|
||||
app['hass'] = hass
|
||||
app[KEY_USE_X_FORWARDED_FOR] = False
|
||||
app[KEY_BANS_ENABLED] = False
|
||||
app[KEY_TRUSTED_NETWORKS] = []
|
||||
return app
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def async_mock_mqtt_component(hass, config=None):
|
||||
"""Mock the MQTT component."""
|
||||
|
@ -9,7 +9,7 @@ from uvcclient import nvr
|
||||
|
||||
from homeassistant.setup import setup_component
|
||||
from homeassistant.components.camera import uvc
|
||||
from tests.common import get_test_home_assistant, mock_http_component
|
||||
from tests.common import get_test_home_assistant
|
||||
|
||||
|
||||
class TestUVCSetup(unittest.TestCase):
|
||||
@ -18,7 +18,6 @@ class TestUVCSetup(unittest.TestCase):
|
||||
def setUp(self):
|
||||
"""Setup things to be run when tests are started."""
|
||||
self.hass = get_test_home_assistant()
|
||||
mock_http_component(self.hass)
|
||||
|
||||
def tearDown(self):
|
||||
"""Stop everything that was started."""
|
||||
|
@ -14,7 +14,7 @@ def test_setup_check_env_prevents_load(hass, loop):
|
||||
with patch.dict(os.environ, clear=True), \
|
||||
patch.object(config, 'SECTIONS', ['hassbian']), \
|
||||
patch('homeassistant.components.http.'
|
||||
'HomeAssistantWSGI.register_view') as reg_view:
|
||||
'HomeAssistantHTTP.register_view') as reg_view:
|
||||
loop.run_until_complete(async_setup_component(hass, 'config', {}))
|
||||
assert 'config' in hass.config.components
|
||||
assert reg_view.called is False
|
||||
@ -25,7 +25,7 @@ def test_setup_check_env_works(hass, loop):
|
||||
with patch.dict(os.environ, {'FORCE_HASSBIAN': '1'}), \
|
||||
patch.object(config, 'SECTIONS', ['hassbian']), \
|
||||
patch('homeassistant.components.http.'
|
||||
'HomeAssistantWSGI.register_view') as reg_view:
|
||||
'HomeAssistantHTTP.register_view') as reg_view:
|
||||
loop.run_until_complete(async_setup_component(hass, 'config', {}))
|
||||
assert 'config' in hass.config.components
|
||||
assert len(reg_view.mock_calls) == 2
|
||||
|
@ -2,19 +2,11 @@
|
||||
import asyncio
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.const import EVENT_COMPONENT_LOADED
|
||||
from homeassistant.setup import async_setup_component, ATTR_COMPONENT
|
||||
from homeassistant.components import config
|
||||
|
||||
from tests.common import mock_http_component, mock_coro, mock_component
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def stub_http(hass):
|
||||
"""Stub the HTTP component."""
|
||||
mock_http_component(hass)
|
||||
from tests.common import mock_coro, mock_component
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
|
@ -3,28 +3,30 @@ import asyncio
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.bootstrap import async_setup_component
|
||||
from homeassistant.components import config
|
||||
|
||||
from homeassistant.components.zwave import DATA_NETWORK, const
|
||||
from homeassistant.components.config.zwave import (
|
||||
ZWaveNodeValueView, ZWaveNodeGroupView, ZWaveNodeConfigView,
|
||||
ZWaveUserCodeView, ZWaveConfigWriteView)
|
||||
from tests.common import mock_http_component_app
|
||||
from tests.mock.zwave import MockNode, MockValue, MockEntityValues
|
||||
|
||||
|
||||
VIEW_NAME = 'api:config:zwave:device_config'
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_get_device_config(hass, test_client):
|
||||
"""Test getting device config."""
|
||||
@pytest.fixture
|
||||
def client(loop, hass, test_client):
|
||||
"""Client to communicate with Z-Wave config views."""
|
||||
with patch.object(config, 'SECTIONS', ['zwave']):
|
||||
yield from async_setup_component(hass, 'config', {})
|
||||
loop.run_until_complete(async_setup_component(hass, 'config', {}))
|
||||
|
||||
client = yield from test_client(hass.http.app)
|
||||
return loop.run_until_complete(test_client(hass.http.app))
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_get_device_config(client):
|
||||
"""Test getting device config."""
|
||||
def mock_read(path):
|
||||
"""Mock reading data."""
|
||||
return {
|
||||
@ -47,13 +49,8 @@ def test_get_device_config(hass, test_client):
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_update_device_config(hass, test_client):
|
||||
def test_update_device_config(client):
|
||||
"""Test updating device config."""
|
||||
with patch.object(config, 'SECTIONS', ['zwave']):
|
||||
yield from async_setup_component(hass, 'config', {})
|
||||
|
||||
client = yield from test_client(hass.http.app)
|
||||
|
||||
orig_data = {
|
||||
'hello.beer': {
|
||||
'ignored': True,
|
||||
@ -90,13 +87,8 @@ def test_update_device_config(hass, test_client):
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_update_device_config_invalid_key(hass, test_client):
|
||||
def test_update_device_config_invalid_key(client):
|
||||
"""Test updating device config."""
|
||||
with patch.object(config, 'SECTIONS', ['zwave']):
|
||||
yield from async_setup_component(hass, 'config', {})
|
||||
|
||||
client = yield from test_client(hass.http.app)
|
||||
|
||||
resp = yield from client.post(
|
||||
'/api/config/zwave/device_config/invalid_entity', data=json.dumps({
|
||||
'polling_intensity': 2
|
||||
@ -106,13 +98,8 @@ def test_update_device_config_invalid_key(hass, test_client):
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_update_device_config_invalid_data(hass, test_client):
|
||||
def test_update_device_config_invalid_data(client):
|
||||
"""Test updating device config."""
|
||||
with patch.object(config, 'SECTIONS', ['zwave']):
|
||||
yield from async_setup_component(hass, 'config', {})
|
||||
|
||||
client = yield from test_client(hass.http.app)
|
||||
|
||||
resp = yield from client.post(
|
||||
'/api/config/zwave/device_config/hello.beer', data=json.dumps({
|
||||
'invalid_option': 2
|
||||
@ -122,13 +109,8 @@ def test_update_device_config_invalid_data(hass, test_client):
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_update_device_config_invalid_json(hass, test_client):
|
||||
def test_update_device_config_invalid_json(client):
|
||||
"""Test updating device config."""
|
||||
with patch.object(config, 'SECTIONS', ['zwave']):
|
||||
yield from async_setup_component(hass, 'config', {})
|
||||
|
||||
client = yield from test_client(hass.http.app)
|
||||
|
||||
resp = yield from client.post(
|
||||
'/api/config/zwave/device_config/hello.beer', data='not json')
|
||||
|
||||
@ -136,11 +118,8 @@ def test_update_device_config_invalid_json(hass, test_client):
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_get_values(hass, test_client):
|
||||
def test_get_values(hass, client):
|
||||
"""Test getting values on node."""
|
||||
app = mock_http_component_app(hass)
|
||||
ZWaveNodeValueView().register(app.router)
|
||||
|
||||
node = MockNode(node_id=1)
|
||||
value = MockValue(value_id=123456, node=node, label='Test Label',
|
||||
instance=1, index=2, poll_intensity=4)
|
||||
@ -150,8 +129,6 @@ def test_get_values(hass, test_client):
|
||||
values2 = MockEntityValues(primary=value2)
|
||||
hass.data[const.DATA_ENTITY_VALUES] = [values, values2]
|
||||
|
||||
client = yield from test_client(app)
|
||||
|
||||
resp = yield from client.get('/api/zwave/values/1')
|
||||
|
||||
assert resp.status == 200
|
||||
@ -168,11 +145,8 @@ def test_get_values(hass, test_client):
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_get_groups(hass, test_client):
|
||||
def test_get_groups(hass, client):
|
||||
"""Test getting groupdata on node."""
|
||||
app = mock_http_component_app(hass)
|
||||
ZWaveNodeGroupView().register(app.router)
|
||||
|
||||
network = hass.data[DATA_NETWORK] = MagicMock()
|
||||
node = MockNode(node_id=2)
|
||||
node.groups.associations = 'assoc'
|
||||
@ -182,8 +156,6 @@ def test_get_groups(hass, test_client):
|
||||
node.groups = {1: node.groups}
|
||||
network.nodes = {2: node}
|
||||
|
||||
client = yield from test_client(app)
|
||||
|
||||
resp = yield from client.get('/api/zwave/groups/2')
|
||||
|
||||
assert resp.status == 200
|
||||
@ -200,18 +172,13 @@ def test_get_groups(hass, test_client):
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_get_groups_nogroups(hass, test_client):
|
||||
def test_get_groups_nogroups(hass, client):
|
||||
"""Test getting groupdata on node with no groups."""
|
||||
app = mock_http_component_app(hass)
|
||||
ZWaveNodeGroupView().register(app.router)
|
||||
|
||||
network = hass.data[DATA_NETWORK] = MagicMock()
|
||||
node = MockNode(node_id=2)
|
||||
|
||||
network.nodes = {2: node}
|
||||
|
||||
client = yield from test_client(app)
|
||||
|
||||
resp = yield from client.get('/api/zwave/groups/2')
|
||||
|
||||
assert resp.status == 200
|
||||
@ -221,16 +188,11 @@ def test_get_groups_nogroups(hass, test_client):
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_get_groups_nonode(hass, test_client):
|
||||
def test_get_groups_nonode(hass, client):
|
||||
"""Test getting groupdata on nonexisting node."""
|
||||
app = mock_http_component_app(hass)
|
||||
ZWaveNodeGroupView().register(app.router)
|
||||
|
||||
network = hass.data[DATA_NETWORK] = MagicMock()
|
||||
network.nodes = {1: 1, 5: 5}
|
||||
|
||||
client = yield from test_client(app)
|
||||
|
||||
resp = yield from client.get('/api/zwave/groups/2')
|
||||
|
||||
assert resp.status == 404
|
||||
@ -240,11 +202,8 @@ def test_get_groups_nonode(hass, test_client):
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_get_config(hass, test_client):
|
||||
def test_get_config(hass, client):
|
||||
"""Test getting config on node."""
|
||||
app = mock_http_component_app(hass)
|
||||
ZWaveNodeConfigView().register(app.router)
|
||||
|
||||
network = hass.data[DATA_NETWORK] = MagicMock()
|
||||
node = MockNode(node_id=2)
|
||||
value = MockValue(
|
||||
@ -261,8 +220,6 @@ def test_get_config(hass, test_client):
|
||||
network.nodes = {2: node}
|
||||
node.get_values.return_value = node.values
|
||||
|
||||
client = yield from test_client(app)
|
||||
|
||||
resp = yield from client.get('/api/zwave/config/2')
|
||||
|
||||
assert resp.status == 200
|
||||
@ -278,19 +235,14 @@ def test_get_config(hass, test_client):
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_get_config_noconfig_node(hass, test_client):
|
||||
def test_get_config_noconfig_node(hass, client):
|
||||
"""Test getting config on node without config."""
|
||||
app = mock_http_component_app(hass)
|
||||
ZWaveNodeConfigView().register(app.router)
|
||||
|
||||
network = hass.data[DATA_NETWORK] = MagicMock()
|
||||
node = MockNode(node_id=2)
|
||||
|
||||
network.nodes = {2: node}
|
||||
node.get_values.return_value = node.values
|
||||
|
||||
client = yield from test_client(app)
|
||||
|
||||
resp = yield from client.get('/api/zwave/config/2')
|
||||
|
||||
assert resp.status == 200
|
||||
@ -300,16 +252,11 @@ def test_get_config_noconfig_node(hass, test_client):
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_get_config_nonode(hass, test_client):
|
||||
def test_get_config_nonode(hass, client):
|
||||
"""Test getting config on nonexisting node."""
|
||||
app = mock_http_component_app(hass)
|
||||
ZWaveNodeConfigView().register(app.router)
|
||||
|
||||
network = hass.data[DATA_NETWORK] = MagicMock()
|
||||
network.nodes = {1: 1, 5: 5}
|
||||
|
||||
client = yield from test_client(app)
|
||||
|
||||
resp = yield from client.get('/api/zwave/config/2')
|
||||
|
||||
assert resp.status == 404
|
||||
@ -319,16 +266,11 @@ def test_get_config_nonode(hass, test_client):
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_get_usercodes_nonode(hass, test_client):
|
||||
def test_get_usercodes_nonode(hass, client):
|
||||
"""Test getting usercodes on nonexisting node."""
|
||||
app = mock_http_component_app(hass)
|
||||
ZWaveUserCodeView().register(app.router)
|
||||
|
||||
network = hass.data[DATA_NETWORK] = MagicMock()
|
||||
network.nodes = {1: 1, 5: 5}
|
||||
|
||||
client = yield from test_client(app)
|
||||
|
||||
resp = yield from client.get('/api/zwave/usercodes/2')
|
||||
|
||||
assert resp.status == 404
|
||||
@ -338,11 +280,8 @@ def test_get_usercodes_nonode(hass, test_client):
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_get_usercodes(hass, test_client):
|
||||
def test_get_usercodes(hass, client):
|
||||
"""Test getting usercodes on node."""
|
||||
app = mock_http_component_app(hass)
|
||||
ZWaveUserCodeView().register(app.router)
|
||||
|
||||
network = hass.data[DATA_NETWORK] = MagicMock()
|
||||
node = MockNode(node_id=18,
|
||||
command_classes=[const.COMMAND_CLASS_USER_CODE])
|
||||
@ -356,8 +295,6 @@ def test_get_usercodes(hass, test_client):
|
||||
network.nodes = {18: node}
|
||||
node.get_values.return_value = node.values
|
||||
|
||||
client = yield from test_client(app)
|
||||
|
||||
resp = yield from client.get('/api/zwave/usercodes/18')
|
||||
|
||||
assert resp.status == 200
|
||||
@ -369,19 +306,14 @@ def test_get_usercodes(hass, test_client):
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_get_usercode_nousercode_node(hass, test_client):
|
||||
def test_get_usercode_nousercode_node(hass, client):
|
||||
"""Test getting usercodes on node without usercodes."""
|
||||
app = mock_http_component_app(hass)
|
||||
ZWaveUserCodeView().register(app.router)
|
||||
|
||||
network = hass.data[DATA_NETWORK] = MagicMock()
|
||||
node = MockNode(node_id=18)
|
||||
|
||||
network.nodes = {18: node}
|
||||
node.get_values.return_value = node.values
|
||||
|
||||
client = yield from test_client(app)
|
||||
|
||||
resp = yield from client.get('/api/zwave/usercodes/18')
|
||||
|
||||
assert resp.status == 200
|
||||
@ -391,11 +323,8 @@ def test_get_usercode_nousercode_node(hass, test_client):
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_get_usercodes_no_genreuser(hass, test_client):
|
||||
def test_get_usercodes_no_genreuser(hass, client):
|
||||
"""Test getting usercodes on node missing genre user."""
|
||||
app = mock_http_component_app(hass)
|
||||
ZWaveUserCodeView().register(app.router)
|
||||
|
||||
network = hass.data[DATA_NETWORK] = MagicMock()
|
||||
node = MockNode(node_id=18,
|
||||
command_classes=[const.COMMAND_CLASS_USER_CODE])
|
||||
@ -409,8 +338,6 @@ def test_get_usercodes_no_genreuser(hass, test_client):
|
||||
network.nodes = {18: node}
|
||||
node.get_values.return_value = node.values
|
||||
|
||||
client = yield from test_client(app)
|
||||
|
||||
resp = yield from client.get('/api/zwave/usercodes/18')
|
||||
|
||||
assert resp.status == 200
|
||||
@ -420,13 +347,8 @@ def test_get_usercodes_no_genreuser(hass, test_client):
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_save_config_no_network(hass, test_client):
|
||||
def test_save_config_no_network(hass, client):
|
||||
"""Test saving configuration without network data."""
|
||||
app = mock_http_component_app(hass)
|
||||
ZWaveConfigWriteView().register(app.router)
|
||||
|
||||
client = yield from test_client(app)
|
||||
|
||||
resp = yield from client.post('/api/zwave/saveconfig')
|
||||
|
||||
assert resp.status == 404
|
||||
@ -435,15 +357,10 @@ def test_save_config_no_network(hass, test_client):
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_save_config(hass, test_client):
|
||||
def test_save_config(hass, client):
|
||||
"""Test saving configuration."""
|
||||
app = mock_http_component_app(hass)
|
||||
ZWaveConfigWriteView().register(app.router)
|
||||
|
||||
network = hass.data[DATA_NETWORK] = MagicMock()
|
||||
|
||||
client = yield from test_client(app)
|
||||
|
||||
resp = yield from client.post('/api/zwave/saveconfig')
|
||||
|
||||
assert resp.status == 200
|
||||
|
@ -5,11 +5,10 @@ import logging
|
||||
from unittest.mock import patch, MagicMock
|
||||
import aioautomatic
|
||||
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.components.device_tracker.automatic import (
|
||||
async_setup_scanner)
|
||||
|
||||
from tests.common import mock_http_component
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -23,8 +22,7 @@ def test_invalid_credentials(
|
||||
mock_open, mock_isfile, mock_makedirs, mock_json_dump, mock_json_load,
|
||||
mock_create_session, hass):
|
||||
"""Test with invalid credentials."""
|
||||
mock_http_component(hass)
|
||||
|
||||
hass.loop.run_until_complete(async_setup_component(hass, 'http', {}))
|
||||
mock_json_load.return_value = {'refresh_token': 'bad_token'}
|
||||
|
||||
@asyncio.coroutine
|
||||
@ -59,8 +57,7 @@ def test_valid_credentials(
|
||||
mock_open, mock_isfile, mock_makedirs, mock_json_dump, mock_json_load,
|
||||
mock_ws_connect, mock_create_session, hass):
|
||||
"""Test with valid credentials."""
|
||||
mock_http_component(hass)
|
||||
|
||||
hass.loop.run_until_complete(async_setup_component(hass, 'http', {}))
|
||||
mock_json_load.return_value = {'refresh_token': 'good_token'}
|
||||
|
||||
session = MagicMock()
|
||||
|
@ -1 +1,38 @@
|
||||
"""Tests for the HTTP component."""
|
||||
import asyncio
|
||||
from ipaddress import ip_address
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from homeassistant.components.http.const import KEY_REAL_IP
|
||||
|
||||
|
||||
def mock_real_ip(app):
|
||||
"""Inject middleware to mock real IP.
|
||||
|
||||
Returns a function to set the real IP.
|
||||
"""
|
||||
ip_to_mock = None
|
||||
|
||||
def set_ip_to_mock(value):
|
||||
nonlocal ip_to_mock
|
||||
ip_to_mock = value
|
||||
|
||||
@asyncio.coroutine
|
||||
@web.middleware
|
||||
def mock_real_ip(request, handler):
|
||||
"""Mock Real IP middleware."""
|
||||
nonlocal ip_to_mock
|
||||
|
||||
request[KEY_REAL_IP] = ip_address(ip_to_mock)
|
||||
|
||||
return (yield from handler(request))
|
||||
|
||||
@asyncio.coroutine
|
||||
def real_ip_startup(app):
|
||||
"""Startup of real ip."""
|
||||
app.middlewares.insert(0, mock_real_ip)
|
||||
|
||||
app.on_startup.append(real_ip_startup)
|
||||
|
||||
return set_ip_to_mock
|
||||
|
@ -1,195 +1,156 @@
|
||||
"""The tests for the Home Assistant HTTP component."""
|
||||
# pylint: disable=protected-access
|
||||
import asyncio
|
||||
from ipaddress import ip_address, ip_network
|
||||
from ipaddress import ip_network
|
||||
from unittest.mock import patch
|
||||
|
||||
import aiohttp
|
||||
from aiohttp import BasicAuth, web
|
||||
from aiohttp.web_exceptions import HTTPUnauthorized
|
||||
import pytest
|
||||
|
||||
from homeassistant import const
|
||||
from homeassistant.const import HTTP_HEADER_HA_AUTH
|
||||
from homeassistant.setup import async_setup_component
|
||||
import homeassistant.components.http as http
|
||||
from homeassistant.components.http.const import (
|
||||
KEY_TRUSTED_NETWORKS, KEY_USE_X_FORWARDED_FOR, HTTP_HEADER_X_FORWARDED_FOR)
|
||||
from homeassistant.components.http.auth import setup_auth
|
||||
from homeassistant.components.http.real_ip import setup_real_ip
|
||||
from homeassistant.components.http.const import KEY_AUTHENTICATED
|
||||
|
||||
from . import mock_real_ip
|
||||
|
||||
API_PASSWORD = 'test1234'
|
||||
|
||||
# Don't add 127.0.0.1/::1 as trusted, as it may interfere with other test cases
|
||||
TRUSTED_NETWORKS = ['192.0.2.0/24', '2001:DB8:ABCD::/48', '100.64.0.1',
|
||||
'FD01:DB8::1']
|
||||
TRUSTED_NETWORKS = [
|
||||
ip_network('192.0.2.0/24'),
|
||||
ip_network('2001:DB8:ABCD::/48'),
|
||||
ip_network('100.64.0.1'),
|
||||
ip_network('FD01:DB8::1'),
|
||||
]
|
||||
TRUSTED_ADDRESSES = ['100.64.0.1', '192.0.2.100', 'FD01:DB8::1',
|
||||
'2001:DB8:ABCD::1']
|
||||
UNTRUSTED_ADDRESSES = ['198.51.100.1', '2001:DB8:FA1::1', '127.0.0.1', '::1']
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_api_client(hass, test_client):
|
||||
"""Start the Hass HTTP component."""
|
||||
hass.loop.run_until_complete(async_setup_component(hass, 'api', {
|
||||
'http': {
|
||||
http.CONF_API_PASSWORD: API_PASSWORD,
|
||||
}
|
||||
}))
|
||||
return hass.loop.run_until_complete(test_client(hass.http.app))
|
||||
@asyncio.coroutine
|
||||
def mock_handler(request):
|
||||
"""Return if request was authenticated."""
|
||||
if not request[KEY_AUTHENTICATED]:
|
||||
raise HTTPUnauthorized
|
||||
return web.Response(status=200)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_trusted_networks(hass, mock_api_client):
|
||||
"""Mock trusted networks."""
|
||||
hass.http.app[KEY_TRUSTED_NETWORKS] = [
|
||||
ip_network(trusted_network)
|
||||
for trusted_network in TRUSTED_NETWORKS]
|
||||
def app():
|
||||
"""Fixture to setup a web.Application."""
|
||||
app = web.Application()
|
||||
app.router.add_get('/', mock_handler)
|
||||
setup_real_ip(app, False)
|
||||
return app
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_access_denied_without_password(mock_api_client):
|
||||
def test_auth_middleware_loaded_by_default(hass):
|
||||
"""Test accessing to server from banned IP when feature is off."""
|
||||
with patch('homeassistant.components.http.setup_auth') as mock_setup:
|
||||
yield from async_setup_component(hass, 'http', {
|
||||
'http': {}
|
||||
})
|
||||
|
||||
assert len(mock_setup.mock_calls) == 1
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_access_without_password(app, test_client):
|
||||
"""Test access without password."""
|
||||
resp = yield from mock_api_client.get(const.URL_API)
|
||||
setup_auth(app, [], None)
|
||||
client = yield from test_client(app)
|
||||
|
||||
resp = yield from client.get('/')
|
||||
assert resp.status == 200
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_access_with_password_in_header(app, test_client):
|
||||
"""Test access with password in URL."""
|
||||
setup_auth(app, [], API_PASSWORD)
|
||||
client = yield from test_client(app)
|
||||
|
||||
req = yield from client.get(
|
||||
'/', headers={HTTP_HEADER_HA_AUTH: API_PASSWORD})
|
||||
assert req.status == 200
|
||||
|
||||
req = yield from client.get(
|
||||
'/', headers={HTTP_HEADER_HA_AUTH: 'wrong-pass'})
|
||||
assert req.status == 401
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_access_with_password_in_query(app, test_client):
|
||||
"""Test access without password."""
|
||||
setup_auth(app, [], API_PASSWORD)
|
||||
client = yield from test_client(app)
|
||||
|
||||
resp = yield from client.get('/', params={
|
||||
'api_password': API_PASSWORD
|
||||
})
|
||||
assert resp.status == 200
|
||||
|
||||
resp = yield from client.get('/')
|
||||
assert resp.status == 401
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_access_denied_with_wrong_password_in_header(mock_api_client):
|
||||
"""Test access with wrong password."""
|
||||
resp = yield from mock_api_client.get(const.URL_API, headers={
|
||||
const.HTTP_HEADER_HA_AUTH: 'wrongpassword'
|
||||
resp = yield from client.get('/', params={
|
||||
'api_password': 'wrong-password'
|
||||
})
|
||||
assert resp.status == 401
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_access_denied_with_x_forwarded_for(hass, mock_api_client,
|
||||
mock_trusted_networks):
|
||||
"""Test access denied through the X-Forwarded-For http header."""
|
||||
hass.http.use_x_forwarded_for = True
|
||||
for remote_addr in UNTRUSTED_ADDRESSES:
|
||||
resp = yield from mock_api_client.get(const.URL_API, headers={
|
||||
HTTP_HEADER_X_FORWARDED_FOR: remote_addr})
|
||||
|
||||
assert resp.status == 401, \
|
||||
"{} shouldn't be trusted".format(remote_addr)
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_access_denied_with_untrusted_ip(mock_api_client,
|
||||
mock_trusted_networks):
|
||||
"""Test access with an untrusted ip address."""
|
||||
for remote_addr in UNTRUSTED_ADDRESSES:
|
||||
with patch('homeassistant.components.http.'
|
||||
'util.get_real_ip',
|
||||
return_value=ip_address(remote_addr)):
|
||||
resp = yield from mock_api_client.get(
|
||||
const.URL_API, params={'api_password': ''})
|
||||
|
||||
assert resp.status == 401, \
|
||||
"{} shouldn't be trusted".format(remote_addr)
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_access_with_password_in_header(mock_api_client, caplog):
|
||||
"""Test access with password in URL."""
|
||||
# Hide logging from requests package that we use to test logging
|
||||
req = yield from mock_api_client.get(
|
||||
const.URL_API, headers={const.HTTP_HEADER_HA_AUTH: API_PASSWORD})
|
||||
|
||||
assert req.status == 200
|
||||
|
||||
logs = caplog.text
|
||||
|
||||
assert const.URL_API in logs
|
||||
assert API_PASSWORD not in logs
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_access_denied_with_wrong_password_in_url(mock_api_client):
|
||||
"""Test access with wrong password."""
|
||||
resp = yield from mock_api_client.get(
|
||||
const.URL_API, params={'api_password': 'wrongpassword'})
|
||||
|
||||
assert resp.status == 401
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_access_with_password_in_url(mock_api_client, caplog):
|
||||
"""Test access with password in URL."""
|
||||
req = yield from mock_api_client.get(
|
||||
const.URL_API, params={'api_password': API_PASSWORD})
|
||||
|
||||
assert req.status == 200
|
||||
|
||||
logs = caplog.text
|
||||
|
||||
assert const.URL_API in logs
|
||||
assert API_PASSWORD not in logs
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_access_granted_with_x_forwarded_for(hass, mock_api_client, caplog,
|
||||
mock_trusted_networks):
|
||||
"""Test access denied through the X-Forwarded-For http header."""
|
||||
hass.http.app[KEY_USE_X_FORWARDED_FOR] = True
|
||||
for remote_addr in TRUSTED_ADDRESSES:
|
||||
resp = yield from mock_api_client.get(const.URL_API, headers={
|
||||
HTTP_HEADER_X_FORWARDED_FOR: remote_addr})
|
||||
|
||||
assert resp.status == 200, \
|
||||
"{} should be trusted".format(remote_addr)
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_access_granted_with_trusted_ip(mock_api_client, caplog,
|
||||
mock_trusted_networks):
|
||||
"""Test access with trusted addresses."""
|
||||
for remote_addr in TRUSTED_ADDRESSES:
|
||||
with patch('homeassistant.components.http.'
|
||||
'auth.get_real_ip',
|
||||
return_value=ip_address(remote_addr)):
|
||||
resp = yield from mock_api_client.get(
|
||||
const.URL_API, params={'api_password': ''})
|
||||
|
||||
assert resp.status == 200, \
|
||||
'{} should be trusted'.format(remote_addr)
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_basic_auth_works(mock_api_client, caplog):
|
||||
def test_basic_auth_works(app, test_client):
|
||||
"""Test access with basic authentication."""
|
||||
req = yield from mock_api_client.get(
|
||||
const.URL_API,
|
||||
auth=aiohttp.BasicAuth('homeassistant', API_PASSWORD))
|
||||
setup_auth(app, [], API_PASSWORD)
|
||||
client = yield from test_client(app)
|
||||
|
||||
req = yield from client.get(
|
||||
'/',
|
||||
auth=BasicAuth('homeassistant', API_PASSWORD))
|
||||
assert req.status == 200
|
||||
assert const.URL_API in caplog.text
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_basic_auth_username_homeassistant(mock_api_client, caplog):
|
||||
"""Test access with basic auth requires username homeassistant."""
|
||||
req = yield from mock_api_client.get(
|
||||
const.URL_API,
|
||||
auth=aiohttp.BasicAuth('wrong_username', API_PASSWORD))
|
||||
|
||||
req = yield from client.get(
|
||||
'/',
|
||||
auth=BasicAuth('wrong_username', API_PASSWORD))
|
||||
assert req.status == 401
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_basic_auth_wrong_password(mock_api_client, caplog):
|
||||
"""Test access with basic auth not allowed with wrong password."""
|
||||
req = yield from mock_api_client.get(
|
||||
const.URL_API,
|
||||
auth=aiohttp.BasicAuth('homeassistant', 'wrong password'))
|
||||
|
||||
req = yield from client.get(
|
||||
'/',
|
||||
auth=BasicAuth('homeassistant', 'wrong password'))
|
||||
assert req.status == 401
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_authorization_header_must_be_basic_type(mock_api_client, caplog):
|
||||
"""Test only basic authorization is allowed for auth header."""
|
||||
req = yield from mock_api_client.get(
|
||||
const.URL_API,
|
||||
req = yield from client.get(
|
||||
'/',
|
||||
headers={
|
||||
'authorization': 'NotBasic abcdefg'
|
||||
})
|
||||
|
||||
assert req.status == 401
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_access_with_trusted_ip(test_client):
|
||||
"""Test access with an untrusted ip address."""
|
||||
app = web.Application()
|
||||
app.router.add_get('/', mock_handler)
|
||||
|
||||
setup_auth(app, TRUSTED_NETWORKS, 'some-pass')
|
||||
|
||||
set_mock_ip = mock_real_ip(app)
|
||||
client = yield from test_client(app)
|
||||
|
||||
for remote_addr in UNTRUSTED_ADDRESSES:
|
||||
set_mock_ip(remote_addr)
|
||||
resp = yield from client.get('/')
|
||||
assert resp.status == 401, \
|
||||
"{} shouldn't be trusted".format(remote_addr)
|
||||
|
||||
for remote_addr in TRUSTED_ADDRESSES:
|
||||
set_mock_ip(remote_addr)
|
||||
resp = yield from client.get('/')
|
||||
assert resp.status == 200, \
|
||||
"{} should be trusted".format(remote_addr)
|
||||
|
@ -1,91 +1,96 @@
|
||||
"""The tests for the Home Assistant HTTP component."""
|
||||
# pylint: disable=protected-access
|
||||
import asyncio
|
||||
from ipaddress import ip_address
|
||||
from unittest.mock import patch, mock_open
|
||||
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
from aiohttp.web_exceptions import HTTPUnauthorized
|
||||
|
||||
from homeassistant import const
|
||||
from homeassistant.setup import async_setup_component
|
||||
import homeassistant.components.http as http
|
||||
from homeassistant.components.http.const import (
|
||||
KEY_BANS_ENABLED, KEY_LOGIN_THRESHOLD, KEY_BANNED_IPS)
|
||||
from homeassistant.components.http.ban import IpBan, IP_BANS_FILE
|
||||
from homeassistant.components.http.ban import (
|
||||
IpBan, IP_BANS_FILE, setup_bans, KEY_BANNED_IPS)
|
||||
|
||||
from . import mock_real_ip
|
||||
|
||||
API_PASSWORD = 'test1234'
|
||||
BANNED_IPS = ['200.201.202.203', '100.64.0.2']
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_api_client(hass, test_client):
|
||||
"""Start the Hass HTTP component."""
|
||||
hass.loop.run_until_complete(async_setup_component(hass, 'api', {
|
||||
'http': {
|
||||
http.CONF_API_PASSWORD: API_PASSWORD,
|
||||
}
|
||||
}))
|
||||
hass.http.app[KEY_BANNED_IPS] = [IpBan(banned_ip) for banned_ip
|
||||
in BANNED_IPS]
|
||||
return hass.loop.run_until_complete(test_client(hass.http.app))
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_access_from_banned_ip(hass, mock_api_client):
|
||||
def test_access_from_banned_ip(hass, test_client):
|
||||
"""Test accessing to server from banned IP. Both trusted and not."""
|
||||
hass.http.app[KEY_BANS_ENABLED] = True
|
||||
app = web.Application()
|
||||
setup_bans(hass, app, 5)
|
||||
set_real_ip = mock_real_ip(app)
|
||||
|
||||
with patch('homeassistant.components.http.ban.load_ip_bans_config',
|
||||
return_value=[IpBan(banned_ip) for banned_ip
|
||||
in BANNED_IPS]):
|
||||
client = yield from test_client(app)
|
||||
|
||||
for remote_addr in BANNED_IPS:
|
||||
with patch('homeassistant.components.http.'
|
||||
'ban.get_real_ip',
|
||||
return_value=ip_address(remote_addr)):
|
||||
resp = yield from mock_api_client.get(
|
||||
const.URL_API)
|
||||
assert resp.status == 403
|
||||
set_real_ip(remote_addr)
|
||||
resp = yield from client.get('/')
|
||||
assert resp.status == 403
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_access_from_banned_ip_when_ban_is_off(hass, mock_api_client):
|
||||
def test_ban_middleware_not_loaded_by_config(hass):
|
||||
"""Test accessing to server from banned IP when feature is off."""
|
||||
hass.http.app[KEY_BANS_ENABLED] = False
|
||||
for remote_addr in BANNED_IPS:
|
||||
with patch('homeassistant.components.http.'
|
||||
'ban.get_real_ip',
|
||||
return_value=ip_address(remote_addr)):
|
||||
resp = yield from mock_api_client.get(
|
||||
const.URL_API,
|
||||
headers={const.HTTP_HEADER_HA_AUTH: API_PASSWORD})
|
||||
assert resp.status == 200
|
||||
with patch('homeassistant.components.http.setup_bans') as mock_setup:
|
||||
yield from async_setup_component(hass, 'http', {
|
||||
'http': {
|
||||
http.CONF_IP_BAN_ENABLED: False,
|
||||
}
|
||||
})
|
||||
|
||||
assert len(mock_setup.mock_calls) == 0
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_ip_bans_file_creation(hass, mock_api_client):
|
||||
def test_ban_middleware_loaded_by_default(hass):
|
||||
"""Test accessing to server from banned IP when feature is off."""
|
||||
with patch('homeassistant.components.http.setup_bans') as mock_setup:
|
||||
yield from async_setup_component(hass, 'http', {
|
||||
'http': {}
|
||||
})
|
||||
|
||||
assert len(mock_setup.mock_calls) == 1
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_ip_bans_file_creation(hass, test_client):
|
||||
"""Testing if banned IP file created."""
|
||||
hass.http.app[KEY_BANS_ENABLED] = True
|
||||
hass.http.app[KEY_LOGIN_THRESHOLD] = 1
|
||||
app = web.Application()
|
||||
app['hass'] = hass
|
||||
|
||||
@asyncio.coroutine
|
||||
def unauth_handler(request):
|
||||
"""Return a mock web response."""
|
||||
raise HTTPUnauthorized
|
||||
|
||||
app.router.add_get('/', unauth_handler)
|
||||
setup_bans(hass, app, 1)
|
||||
mock_real_ip(app)("200.201.202.204")
|
||||
|
||||
with patch('homeassistant.components.http.ban.load_ip_bans_config',
|
||||
return_value=[IpBan(banned_ip) for banned_ip
|
||||
in BANNED_IPS]):
|
||||
client = yield from test_client(app)
|
||||
|
||||
m = mock_open()
|
||||
|
||||
@asyncio.coroutine
|
||||
def call_server():
|
||||
with patch('homeassistant.components.http.'
|
||||
'ban.get_real_ip',
|
||||
return_value=ip_address("200.201.202.204")):
|
||||
resp = yield from mock_api_client.get(
|
||||
const.URL_API,
|
||||
headers={const.HTTP_HEADER_HA_AUTH: 'Wrong password'})
|
||||
return resp
|
||||
|
||||
with patch('homeassistant.components.http.ban.open', m, create=True):
|
||||
resp = yield from call_server()
|
||||
resp = yield from client.get('/')
|
||||
assert resp.status == 401
|
||||
assert len(hass.http.app[KEY_BANNED_IPS]) == len(BANNED_IPS)
|
||||
assert len(app[KEY_BANNED_IPS]) == len(BANNED_IPS)
|
||||
assert m.call_count == 0
|
||||
|
||||
resp = yield from call_server()
|
||||
resp = yield from client.get('/')
|
||||
assert resp.status == 401
|
||||
assert len(hass.http.app[KEY_BANNED_IPS]) == len(BANNED_IPS) + 1
|
||||
assert len(app[KEY_BANNED_IPS]) == len(BANNED_IPS) + 1
|
||||
m.assert_called_once_with(hass.config.path(IP_BANS_FILE), 'a')
|
||||
|
||||
resp = yield from call_server()
|
||||
resp = yield from client.get('/')
|
||||
assert resp.status == 403
|
||||
assert m.call_count == 1
|
||||
|
104
tests/components/http/test_cors.py
Normal file
104
tests/components/http/test_cors.py
Normal file
@ -0,0 +1,104 @@
|
||||
"""Test cors for the HTTP component."""
|
||||
import asyncio
|
||||
from unittest.mock import patch
|
||||
|
||||
from aiohttp import web
|
||||
from aiohttp.hdrs import (
|
||||
ACCESS_CONTROL_ALLOW_ORIGIN,
|
||||
ACCESS_CONTROL_ALLOW_HEADERS,
|
||||
ACCESS_CONTROL_REQUEST_HEADERS,
|
||||
ACCESS_CONTROL_REQUEST_METHOD,
|
||||
ORIGIN
|
||||
)
|
||||
import pytest
|
||||
|
||||
from homeassistant.const import HTTP_HEADER_HA_AUTH
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.components.http.cors import setup_cors
|
||||
|
||||
|
||||
TRUSTED_ORIGIN = 'https://home-assistant.io'
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_cors_middleware_not_loaded_by_default(hass):
|
||||
"""Test accessing to server from banned IP when feature is off."""
|
||||
with patch('homeassistant.components.http.setup_cors') as mock_setup:
|
||||
yield from async_setup_component(hass, 'http', {
|
||||
'http': {}
|
||||
})
|
||||
|
||||
assert len(mock_setup.mock_calls) == 0
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_cors_middleware_loaded_from_config(hass):
|
||||
"""Test accessing to server from banned IP when feature is off."""
|
||||
with patch('homeassistant.components.http.setup_cors') as mock_setup:
|
||||
yield from async_setup_component(hass, 'http', {
|
||||
'http': {
|
||||
'cors_allowed_origins': ['http://home-assistant.io']
|
||||
}
|
||||
})
|
||||
|
||||
assert len(mock_setup.mock_calls) == 1
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def mock_handler(request):
|
||||
"""Return if request was authenticated."""
|
||||
return web.Response(status=200)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(loop, test_client):
|
||||
"""Fixture to setup a web.Application."""
|
||||
app = web.Application()
|
||||
app.router.add_get('/', mock_handler)
|
||||
setup_cors(app, [TRUSTED_ORIGIN])
|
||||
return loop.run_until_complete(test_client(app))
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_cors_requests(client):
|
||||
"""Test cross origin requests."""
|
||||
req = yield from client.get('/', headers={
|
||||
ORIGIN: TRUSTED_ORIGIN
|
||||
})
|
||||
assert req.status == 200
|
||||
assert req.headers[ACCESS_CONTROL_ALLOW_ORIGIN] == \
|
||||
TRUSTED_ORIGIN
|
||||
|
||||
# With password in URL
|
||||
req = yield from client.get('/', params={
|
||||
'api_password': 'some-pass'
|
||||
}, headers={
|
||||
ORIGIN: TRUSTED_ORIGIN
|
||||
})
|
||||
assert req.status == 200
|
||||
assert req.headers[ACCESS_CONTROL_ALLOW_ORIGIN] == \
|
||||
TRUSTED_ORIGIN
|
||||
|
||||
# With password in headers
|
||||
req = yield from client.get('/', headers={
|
||||
HTTP_HEADER_HA_AUTH: 'some-pass',
|
||||
ORIGIN: TRUSTED_ORIGIN
|
||||
})
|
||||
assert req.status == 200
|
||||
assert req.headers[ACCESS_CONTROL_ALLOW_ORIGIN] == \
|
||||
TRUSTED_ORIGIN
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_cors_preflight_allowed(client):
|
||||
"""Test cross origin resource sharing preflight (OPTIONS) request."""
|
||||
req = yield from client.options('/', headers={
|
||||
ORIGIN: TRUSTED_ORIGIN,
|
||||
ACCESS_CONTROL_REQUEST_METHOD: 'GET',
|
||||
ACCESS_CONTROL_REQUEST_HEADERS: 'x-ha-access'
|
||||
})
|
||||
|
||||
assert req.status == 200
|
||||
assert req.headers[ACCESS_CONTROL_ALLOW_ORIGIN] == TRUSTED_ORIGIN
|
||||
assert req.headers[ACCESS_CONTROL_ALLOW_HEADERS] == \
|
||||
HTTP_HEADER_HA_AUTH.upper()
|
@ -1,124 +1,10 @@
|
||||
"""The tests for the Home Assistant HTTP component."""
|
||||
import asyncio
|
||||
|
||||
from aiohttp.hdrs import (
|
||||
ORIGIN, ACCESS_CONTROL_ALLOW_ORIGIN, ACCESS_CONTROL_ALLOW_HEADERS,
|
||||
ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS,
|
||||
CONTENT_TYPE)
|
||||
import requests
|
||||
from tests.common import get_test_instance_port, get_test_home_assistant
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from homeassistant import const, setup
|
||||
import homeassistant.components.http as http
|
||||
|
||||
API_PASSWORD = 'test1234'
|
||||
SERVER_PORT = get_test_instance_port()
|
||||
HTTP_BASE = '127.0.0.1:{}'.format(SERVER_PORT)
|
||||
HTTP_BASE_URL = 'http://{}'.format(HTTP_BASE)
|
||||
HA_HEADERS = {
|
||||
const.HTTP_HEADER_HA_AUTH: API_PASSWORD,
|
||||
CONTENT_TYPE: const.CONTENT_TYPE_JSON,
|
||||
}
|
||||
CORS_ORIGINS = [HTTP_BASE_URL, HTTP_BASE]
|
||||
|
||||
hass = None
|
||||
|
||||
|
||||
def _url(path=''):
|
||||
"""Helper method to generate URLs."""
|
||||
return HTTP_BASE_URL + path
|
||||
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
def setUpModule():
|
||||
"""Initialize a Home Assistant server."""
|
||||
global hass
|
||||
|
||||
hass = get_test_home_assistant()
|
||||
|
||||
setup.setup_component(
|
||||
hass, http.DOMAIN, {
|
||||
http.DOMAIN: {
|
||||
http.CONF_API_PASSWORD: API_PASSWORD,
|
||||
http.CONF_SERVER_PORT: SERVER_PORT,
|
||||
http.CONF_CORS_ORIGINS: CORS_ORIGINS,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
setup.setup_component(hass, 'api')
|
||||
|
||||
# Registering static path as it caused CORS to blow up
|
||||
hass.http.register_static_path(
|
||||
'/custom_components', hass.config.path('custom_components'))
|
||||
|
||||
hass.start()
|
||||
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
def tearDownModule():
|
||||
"""Stop the Home Assistant server."""
|
||||
hass.stop()
|
||||
|
||||
|
||||
class TestCors:
|
||||
"""Test HTTP component."""
|
||||
|
||||
def test_cors_allowed_with_password_in_url(self):
|
||||
"""Test cross origin resource sharing with password in url."""
|
||||
req = requests.get(_url(const.URL_API),
|
||||
params={'api_password': API_PASSWORD},
|
||||
headers={ORIGIN: HTTP_BASE_URL})
|
||||
|
||||
allow_origin = ACCESS_CONTROL_ALLOW_ORIGIN
|
||||
|
||||
assert req.status_code == 200
|
||||
assert req.headers.get(allow_origin) == HTTP_BASE_URL
|
||||
|
||||
def test_cors_allowed_with_password_in_header(self):
|
||||
"""Test cross origin resource sharing with password in header."""
|
||||
headers = {
|
||||
const.HTTP_HEADER_HA_AUTH: API_PASSWORD,
|
||||
ORIGIN: HTTP_BASE_URL
|
||||
}
|
||||
req = requests.get(_url(const.URL_API), headers=headers)
|
||||
|
||||
allow_origin = ACCESS_CONTROL_ALLOW_ORIGIN
|
||||
|
||||
assert req.status_code == 200
|
||||
assert req.headers.get(allow_origin) == HTTP_BASE_URL
|
||||
|
||||
def test_cors_denied_without_origin_header(self):
|
||||
"""Test cross origin resource sharing with password in header."""
|
||||
headers = {
|
||||
const.HTTP_HEADER_HA_AUTH: API_PASSWORD
|
||||
}
|
||||
req = requests.get(_url(const.URL_API), headers=headers)
|
||||
|
||||
allow_origin = ACCESS_CONTROL_ALLOW_ORIGIN
|
||||
allow_headers = ACCESS_CONTROL_ALLOW_HEADERS
|
||||
|
||||
assert req.status_code == 200
|
||||
assert allow_origin not in req.headers
|
||||
assert allow_headers not in req.headers
|
||||
|
||||
def test_cors_preflight_allowed(self):
|
||||
"""Test cross origin resource sharing preflight (OPTIONS) request."""
|
||||
headers = {
|
||||
ORIGIN: HTTP_BASE_URL,
|
||||
ACCESS_CONTROL_REQUEST_METHOD: 'GET',
|
||||
ACCESS_CONTROL_REQUEST_HEADERS: 'x-ha-access'
|
||||
}
|
||||
req = requests.options(_url(const.URL_API), headers=headers)
|
||||
|
||||
allow_origin = ACCESS_CONTROL_ALLOW_ORIGIN
|
||||
allow_headers = ACCESS_CONTROL_ALLOW_HEADERS
|
||||
|
||||
assert req.status_code == 200
|
||||
assert req.headers.get(allow_origin) == HTTP_BASE_URL
|
||||
assert req.headers.get(allow_headers) == \
|
||||
const.HTTP_HEADER_HA_AUTH.upper()
|
||||
|
||||
|
||||
class TestView(http.HomeAssistantView):
|
||||
"""Test the HTTP views."""
|
||||
@ -133,12 +19,12 @@ class TestView(http.HomeAssistantView):
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_registering_view_while_running(hass, test_client):
|
||||
def test_registering_view_while_running(hass, test_client, unused_port):
|
||||
"""Test that we can register a view while the server is running."""
|
||||
yield from setup.async_setup_component(
|
||||
yield from async_setup_component(
|
||||
hass, http.DOMAIN, {
|
||||
http.DOMAIN: {
|
||||
http.CONF_SERVER_PORT: get_test_instance_port(),
|
||||
http.CONF_SERVER_PORT: unused_port(),
|
||||
}
|
||||
}
|
||||
)
|
||||
@ -151,7 +37,7 @@ def test_registering_view_while_running(hass, test_client):
|
||||
@asyncio.coroutine
|
||||
def test_api_base_url_with_domain(hass):
|
||||
"""Test setting API URL."""
|
||||
result = yield from setup.async_setup_component(hass, 'http', {
|
||||
result = yield from async_setup_component(hass, 'http', {
|
||||
'http': {
|
||||
'base_url': 'example.com'
|
||||
}
|
||||
@ -163,7 +49,7 @@ def test_api_base_url_with_domain(hass):
|
||||
@asyncio.coroutine
|
||||
def test_api_base_url_with_ip(hass):
|
||||
"""Test setting api url."""
|
||||
result = yield from setup.async_setup_component(hass, 'http', {
|
||||
result = yield from async_setup_component(hass, 'http', {
|
||||
'http': {
|
||||
'server_host': '1.1.1.1'
|
||||
}
|
||||
@ -175,7 +61,7 @@ def test_api_base_url_with_ip(hass):
|
||||
@asyncio.coroutine
|
||||
def test_api_base_url_with_ip_port(hass):
|
||||
"""Test setting api url."""
|
||||
result = yield from setup.async_setup_component(hass, 'http', {
|
||||
result = yield from async_setup_component(hass, 'http', {
|
||||
'http': {
|
||||
'base_url': '1.1.1.1:8124'
|
||||
}
|
||||
@ -187,9 +73,34 @@ def test_api_base_url_with_ip_port(hass):
|
||||
@asyncio.coroutine
|
||||
def test_api_no_base_url(hass):
|
||||
"""Test setting api url."""
|
||||
result = yield from setup.async_setup_component(hass, 'http', {
|
||||
result = yield from async_setup_component(hass, 'http', {
|
||||
'http': {
|
||||
}
|
||||
})
|
||||
assert result
|
||||
assert hass.config.api.base_url == 'http://127.0.0.1:8123'
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_not_log_password(hass, unused_port, test_client, caplog):
|
||||
"""Test access with password doesn't get logged."""
|
||||
result = yield from async_setup_component(hass, 'api', {
|
||||
'http': {
|
||||
http.CONF_SERVER_PORT: unused_port(),
|
||||
http.CONF_API_PASSWORD: 'some-pass'
|
||||
}
|
||||
})
|
||||
assert result
|
||||
|
||||
client = yield from test_client(hass.http.app)
|
||||
|
||||
resp = yield from client.get('/api/', params={
|
||||
'api_password': 'some-pass'
|
||||
})
|
||||
|
||||
assert resp.status == 200
|
||||
logs = caplog.text
|
||||
|
||||
# Ensure we don't log API passwords
|
||||
assert '/api/' in logs
|
||||
assert 'some-pass' not in logs
|
||||
|
48
tests/components/http/test_real_ip.py
Normal file
48
tests/components/http/test_real_ip.py
Normal file
@ -0,0 +1,48 @@
|
||||
"""Test real IP middleware."""
|
||||
import asyncio
|
||||
|
||||
from aiohttp import web
|
||||
from aiohttp.hdrs import X_FORWARDED_FOR
|
||||
|
||||
from homeassistant.components.http.real_ip import setup_real_ip
|
||||
from homeassistant.components.http.const import KEY_REAL_IP
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def mock_handler(request):
|
||||
"""Handler that returns the real IP as text."""
|
||||
return web.Response(text=str(request[KEY_REAL_IP]))
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_ignore_x_forwarded_for(test_client):
|
||||
"""Test that we get the IP from the transport."""
|
||||
app = web.Application()
|
||||
app.router.add_get('/', mock_handler)
|
||||
setup_real_ip(app, False)
|
||||
|
||||
mock_api_client = yield from test_client(app)
|
||||
|
||||
resp = yield from mock_api_client.get('/', headers={
|
||||
X_FORWARDED_FOR: '255.255.255.255'
|
||||
})
|
||||
assert resp.status == 200
|
||||
text = yield from resp.text()
|
||||
assert text != '255.255.255.255'
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_use_x_forwarded_for(test_client):
|
||||
"""Test that we get the IP from the transport."""
|
||||
app = web.Application()
|
||||
app.router.add_get('/', mock_handler)
|
||||
setup_real_ip(app, True)
|
||||
|
||||
mock_api_client = yield from test_client(app)
|
||||
|
||||
resp = yield from mock_api_client.get('/', headers={
|
||||
X_FORWARDED_FOR: '255.255.255.255'
|
||||
})
|
||||
assert resp.status == 200
|
||||
text = yield from resp.text()
|
||||
assert text == '255.255.255.255'
|
@ -4,8 +4,7 @@ from unittest.mock import Mock, MagicMock, patch
|
||||
from homeassistant.setup import setup_component
|
||||
import homeassistant.components.mqtt as mqtt
|
||||
|
||||
from tests.common import (
|
||||
get_test_home_assistant, mock_coro, mock_http_component)
|
||||
from tests.common import get_test_home_assistant, mock_coro
|
||||
|
||||
|
||||
class TestMQTT:
|
||||
@ -14,7 +13,9 @@ class TestMQTT:
|
||||
def setup_method(self, method):
|
||||
"""Setup things to be run when tests are started."""
|
||||
self.hass = get_test_home_assistant()
|
||||
mock_http_component(self.hass, 'super_secret')
|
||||
setup_component(self.hass, 'http', {
|
||||
'api_password': 'super_secret'
|
||||
})
|
||||
|
||||
def teardown_method(self, method):
|
||||
"""Stop everything that was started."""
|
||||
|
@ -4,12 +4,10 @@ import json
|
||||
from unittest.mock import patch, MagicMock, mock_open
|
||||
from aiohttp.hdrs import AUTHORIZATION
|
||||
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.util.json import save_json
|
||||
from homeassistant.components.notify import html5
|
||||
|
||||
from tests.common import mock_http_component_app
|
||||
|
||||
CONFIG_FILE = 'file.conf'
|
||||
|
||||
SUBSCRIPTION_1 = {
|
||||
@ -52,6 +50,23 @@ REGISTER_URL = '/api/notify.html5'
|
||||
PUBLISH_URL = '/api/notify.html5/callback'
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def mock_client(hass, test_client, registrations=None):
|
||||
"""Create a test client for HTML5 views."""
|
||||
if registrations is None:
|
||||
registrations = {}
|
||||
|
||||
with patch('homeassistant.components.notify.html5._load_config',
|
||||
return_value=registrations):
|
||||
yield from async_setup_component(hass, 'notify', {
|
||||
'notify': {
|
||||
'platform': 'html5'
|
||||
}
|
||||
})
|
||||
|
||||
return (yield from test_client(hass.http.app))
|
||||
|
||||
|
||||
class TestHtml5Notify(object):
|
||||
"""Tests for HTML5 notify platform."""
|
||||
|
||||
@ -89,8 +104,6 @@ class TestHtml5Notify(object):
|
||||
service.send_message('Hello', target=['device', 'non_existing'],
|
||||
data={'icon': 'beer.png'})
|
||||
|
||||
print(mock_wp.mock_calls)
|
||||
|
||||
assert len(mock_wp.mock_calls) == 3
|
||||
|
||||
# WebPusher constructor
|
||||
@ -104,421 +117,224 @@ class TestHtml5Notify(object):
|
||||
assert payload['body'] == 'Hello'
|
||||
assert payload['icon'] == 'beer.png'
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_registering_new_device_view(self, loop, test_client):
|
||||
"""Test that the HTML view works."""
|
||||
hass = MagicMock()
|
||||
expected = {
|
||||
'unnamed device': SUBSCRIPTION_1,
|
||||
}
|
||||
|
||||
hass.config.path.return_value = CONFIG_FILE
|
||||
service = html5.get_service(hass, {})
|
||||
@asyncio.coroutine
|
||||
def test_registering_new_device_view(hass, test_client):
|
||||
"""Test that the HTML view works."""
|
||||
client = yield from mock_client(hass, test_client)
|
||||
|
||||
assert service is not None
|
||||
|
||||
assert len(hass.mock_calls) == 3
|
||||
|
||||
view = hass.mock_calls[1][1][0]
|
||||
assert view.json_path == hass.config.path.return_value
|
||||
assert view.registrations == {}
|
||||
|
||||
hass.loop = loop
|
||||
app = mock_http_component_app(hass)
|
||||
view.register(app.router)
|
||||
client = yield from test_client(app)
|
||||
hass.http.is_banned_ip.return_value = False
|
||||
with patch('homeassistant.components.notify.html5.save_json') as mock_save:
|
||||
resp = yield from client.post(REGISTER_URL,
|
||||
data=json.dumps(SUBSCRIPTION_1))
|
||||
|
||||
content = yield from resp.text()
|
||||
assert resp.status == 200, content
|
||||
assert view.registrations == expected
|
||||
assert resp.status == 200
|
||||
assert len(mock_save.mock_calls) == 1
|
||||
assert mock_save.mock_calls[0][1][1] == {
|
||||
'unnamed device': SUBSCRIPTION_1,
|
||||
}
|
||||
|
||||
hass.async_add_job.assert_called_with(save_json, CONFIG_FILE, expected)
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_registering_new_device_expiration_view(self, loop, test_client):
|
||||
"""Test that the HTML view works."""
|
||||
hass = MagicMock()
|
||||
expected = {
|
||||
'unnamed device': SUBSCRIPTION_4,
|
||||
}
|
||||
@asyncio.coroutine
|
||||
def test_registering_new_device_expiration_view(hass, test_client):
|
||||
"""Test that the HTML view works."""
|
||||
client = yield from mock_client(hass, test_client)
|
||||
|
||||
hass.config.path.return_value = CONFIG_FILE
|
||||
service = html5.get_service(hass, {})
|
||||
|
||||
assert service is not None
|
||||
|
||||
# assert hass.called
|
||||
assert len(hass.mock_calls) == 3
|
||||
|
||||
view = hass.mock_calls[1][1][0]
|
||||
assert view.json_path == hass.config.path.return_value
|
||||
assert view.registrations == {}
|
||||
|
||||
hass.loop = loop
|
||||
app = mock_http_component_app(hass)
|
||||
view.register(app.router)
|
||||
client = yield from test_client(app)
|
||||
hass.http.is_banned_ip.return_value = False
|
||||
with patch('homeassistant.components.notify.html5.save_json') as mock_save:
|
||||
resp = yield from client.post(REGISTER_URL,
|
||||
data=json.dumps(SUBSCRIPTION_4))
|
||||
|
||||
content = yield from resp.text()
|
||||
assert resp.status == 200, content
|
||||
assert view.registrations == expected
|
||||
assert resp.status == 200
|
||||
assert mock_save.mock_calls[0][1][1] == {
|
||||
'unnamed device': SUBSCRIPTION_4,
|
||||
}
|
||||
|
||||
hass.async_add_job.assert_called_with(save_json, CONFIG_FILE, expected)
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_registering_new_device_fails_view(self, loop, test_client):
|
||||
"""Test subs. are not altered when registering a new device fails."""
|
||||
hass = MagicMock()
|
||||
expected = {}
|
||||
|
||||
hass.config.path.return_value = CONFIG_FILE
|
||||
html5.get_service(hass, {})
|
||||
view = hass.mock_calls[1][1][0]
|
||||
|
||||
hass.loop = loop
|
||||
app = mock_http_component_app(hass)
|
||||
view.register(app.router)
|
||||
client = yield from test_client(app)
|
||||
hass.http.is_banned_ip.return_value = False
|
||||
|
||||
hass.async_add_job.side_effect = HomeAssistantError()
|
||||
@asyncio.coroutine
|
||||
def test_registering_new_device_fails_view(hass, test_client):
|
||||
"""Test subs. are not altered when registering a new device fails."""
|
||||
registrations = {}
|
||||
client = yield from mock_client(hass, test_client, registrations)
|
||||
|
||||
with patch('homeassistant.components.notify.html5.save_json',
|
||||
side_effect=HomeAssistantError()):
|
||||
resp = yield from client.post(REGISTER_URL,
|
||||
data=json.dumps(SUBSCRIPTION_1))
|
||||
data=json.dumps(SUBSCRIPTION_4))
|
||||
|
||||
content = yield from resp.text()
|
||||
assert resp.status == 500, content
|
||||
assert view.registrations == expected
|
||||
assert resp.status == 500
|
||||
assert registrations == {}
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_registering_existing_device_view(self, loop, test_client):
|
||||
"""Test subscription is updated when registering existing device."""
|
||||
hass = MagicMock()
|
||||
expected = {
|
||||
'unnamed device': SUBSCRIPTION_4,
|
||||
}
|
||||
|
||||
hass.config.path.return_value = CONFIG_FILE
|
||||
html5.get_service(hass, {})
|
||||
view = hass.mock_calls[1][1][0]
|
||||
|
||||
hass.loop = loop
|
||||
app = mock_http_component_app(hass)
|
||||
view.register(app.router)
|
||||
client = yield from test_client(app)
|
||||
hass.http.is_banned_ip.return_value = False
|
||||
@asyncio.coroutine
|
||||
def test_registering_existing_device_view(hass, test_client):
|
||||
"""Test subscription is updated when registering existing device."""
|
||||
registrations = {}
|
||||
client = yield from mock_client(hass, test_client, registrations)
|
||||
|
||||
with patch('homeassistant.components.notify.html5.save_json') as mock_save:
|
||||
yield from client.post(REGISTER_URL,
|
||||
data=json.dumps(SUBSCRIPTION_1))
|
||||
resp = yield from client.post(REGISTER_URL,
|
||||
data=json.dumps(SUBSCRIPTION_4))
|
||||
|
||||
content = yield from resp.text()
|
||||
assert resp.status == 200, content
|
||||
assert view.registrations == expected
|
||||
assert resp.status == 200
|
||||
assert mock_save.mock_calls[0][1][1] == {
|
||||
'unnamed device': SUBSCRIPTION_4,
|
||||
}
|
||||
assert registrations == {
|
||||
'unnamed device': SUBSCRIPTION_4,
|
||||
}
|
||||
|
||||
hass.async_add_job.assert_called_with(save_json, CONFIG_FILE, expected)
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_registering_existing_device_fails_view(self, loop, test_client):
|
||||
"""Test sub. is not updated when registering existing device fails."""
|
||||
hass = MagicMock()
|
||||
expected = {
|
||||
'unnamed device': SUBSCRIPTION_1,
|
||||
}
|
||||
|
||||
hass.config.path.return_value = CONFIG_FILE
|
||||
html5.get_service(hass, {})
|
||||
view = hass.mock_calls[1][1][0]
|
||||
|
||||
hass.loop = loop
|
||||
app = mock_http_component_app(hass)
|
||||
view.register(app.router)
|
||||
client = yield from test_client(app)
|
||||
hass.http.is_banned_ip.return_value = False
|
||||
@asyncio.coroutine
|
||||
def test_registering_existing_device_fails_view(hass, test_client):
|
||||
"""Test sub. is not updated when registering existing device fails."""
|
||||
registrations = {}
|
||||
client = yield from mock_client(hass, test_client, registrations)
|
||||
|
||||
with patch('homeassistant.components.notify.html5.save_json') as mock_save:
|
||||
yield from client.post(REGISTER_URL,
|
||||
data=json.dumps(SUBSCRIPTION_1))
|
||||
|
||||
hass.async_add_job.side_effect = HomeAssistantError()
|
||||
mock_save.side_effect = HomeAssistantError
|
||||
resp = yield from client.post(REGISTER_URL,
|
||||
data=json.dumps(SUBSCRIPTION_4))
|
||||
|
||||
content = yield from resp.text()
|
||||
assert resp.status == 500, content
|
||||
assert view.registrations == expected
|
||||
assert resp.status == 500
|
||||
assert registrations == {
|
||||
'unnamed device': SUBSCRIPTION_1,
|
||||
}
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_registering_new_device_validation(self, loop, test_client):
|
||||
"""Test various errors when registering a new device."""
|
||||
hass = MagicMock()
|
||||
|
||||
m = mock_open()
|
||||
with patch(
|
||||
'homeassistant.util.json.open',
|
||||
m, create=True
|
||||
):
|
||||
hass.config.path.return_value = CONFIG_FILE
|
||||
service = html5.get_service(hass, {})
|
||||
@asyncio.coroutine
|
||||
def test_registering_new_device_validation(hass, test_client):
|
||||
"""Test various errors when registering a new device."""
|
||||
client = yield from mock_client(hass, test_client)
|
||||
|
||||
assert service is not None
|
||||
resp = yield from client.post(REGISTER_URL, data=json.dumps({
|
||||
'browser': 'invalid browser',
|
||||
'subscription': 'sub info',
|
||||
}))
|
||||
assert resp.status == 400
|
||||
|
||||
# assert hass.called
|
||||
assert len(hass.mock_calls) == 3
|
||||
resp = yield from client.post(REGISTER_URL, data=json.dumps({
|
||||
'browser': 'chrome',
|
||||
}))
|
||||
assert resp.status == 400
|
||||
|
||||
view = hass.mock_calls[1][1][0]
|
||||
with patch('homeassistant.components.notify.html5.save_json',
|
||||
return_value=False):
|
||||
resp = yield from client.post(REGISTER_URL, data=json.dumps({
|
||||
'browser': 'chrome',
|
||||
'subscription': 'sub info',
|
||||
}))
|
||||
assert resp.status == 400
|
||||
|
||||
hass.loop = loop
|
||||
app = mock_http_component_app(hass)
|
||||
view.register(app.router)
|
||||
client = yield from test_client(app)
|
||||
hass.http.is_banned_ip.return_value = False
|
||||
|
||||
resp = yield from client.post(REGISTER_URL, data=json.dumps({
|
||||
'browser': 'invalid browser',
|
||||
'subscription': 'sub info',
|
||||
}))
|
||||
assert resp.status == 400
|
||||
@asyncio.coroutine
|
||||
def test_unregistering_device_view(hass, test_client):
|
||||
"""Test that the HTML unregister view works."""
|
||||
registrations = {
|
||||
'some device': SUBSCRIPTION_1,
|
||||
'other device': SUBSCRIPTION_2,
|
||||
}
|
||||
client = yield from mock_client(hass, test_client, registrations)
|
||||
|
||||
resp = yield from client.post(REGISTER_URL, data=json.dumps({
|
||||
'browser': 'chrome',
|
||||
}))
|
||||
assert resp.status == 400
|
||||
with patch('homeassistant.components.notify.html5.save_json') as mock_save:
|
||||
resp = yield from client.delete(REGISTER_URL, data=json.dumps({
|
||||
'subscription': SUBSCRIPTION_1['subscription'],
|
||||
}))
|
||||
|
||||
with patch('homeassistant.components.notify.html5.save_json',
|
||||
return_value=False):
|
||||
# resp = view.post(Request(builder.get_environ()))
|
||||
resp = yield from client.post(REGISTER_URL, data=json.dumps({
|
||||
'browser': 'chrome',
|
||||
'subscription': 'sub info',
|
||||
}))
|
||||
assert resp.status == 200
|
||||
assert len(mock_save.mock_calls) == 1
|
||||
assert registrations == {
|
||||
'other device': SUBSCRIPTION_2
|
||||
}
|
||||
|
||||
assert resp.status == 400
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_unregistering_device_view(self, loop, test_client):
|
||||
"""Test that the HTML unregister view works."""
|
||||
hass = MagicMock()
|
||||
@asyncio.coroutine
|
||||
def test_unregister_device_view_handle_unknown_subscription(hass, test_client):
|
||||
"""Test that the HTML unregister view handles unknown subscriptions."""
|
||||
registrations = {}
|
||||
client = yield from mock_client(hass, test_client, registrations)
|
||||
|
||||
config = {
|
||||
'some device': SUBSCRIPTION_1,
|
||||
'other device': SUBSCRIPTION_2,
|
||||
}
|
||||
with patch('homeassistant.components.notify.html5.save_json') as mock_save:
|
||||
resp = yield from client.delete(REGISTER_URL, data=json.dumps({
|
||||
'subscription': SUBSCRIPTION_3['subscription']
|
||||
}))
|
||||
|
||||
m = mock_open(read_data=json.dumps(config))
|
||||
with patch(
|
||||
'homeassistant.util.json.open',
|
||||
m, create=True
|
||||
):
|
||||
hass.config.path.return_value = CONFIG_FILE
|
||||
service = html5.get_service(hass, {})
|
||||
assert resp.status == 200, resp.response
|
||||
assert registrations == {}
|
||||
assert len(mock_save.mock_calls) == 0
|
||||
|
||||
assert service is not None
|
||||
|
||||
# assert hass.called
|
||||
assert len(hass.mock_calls) == 3
|
||||
@asyncio.coroutine
|
||||
def test_unregistering_device_view_handles_save_error(hass, test_client):
|
||||
"""Test that the HTML unregister view handles save errors."""
|
||||
registrations = {
|
||||
'some device': SUBSCRIPTION_1,
|
||||
'other device': SUBSCRIPTION_2,
|
||||
}
|
||||
client = yield from mock_client(hass, test_client, registrations)
|
||||
|
||||
view = hass.mock_calls[1][1][0]
|
||||
assert view.json_path == hass.config.path.return_value
|
||||
assert view.registrations == config
|
||||
with patch('homeassistant.components.notify.html5.save_json',
|
||||
side_effect=HomeAssistantError()):
|
||||
resp = yield from client.delete(REGISTER_URL, data=json.dumps({
|
||||
'subscription': SUBSCRIPTION_1['subscription'],
|
||||
}))
|
||||
|
||||
hass.loop = loop
|
||||
app = mock_http_component_app(hass)
|
||||
view.register(app.router)
|
||||
client = yield from test_client(app)
|
||||
hass.http.is_banned_ip.return_value = False
|
||||
assert resp.status == 500, resp.response
|
||||
assert registrations == {
|
||||
'some device': SUBSCRIPTION_1,
|
||||
'other device': SUBSCRIPTION_2,
|
||||
}
|
||||
|
||||
resp = yield from client.delete(REGISTER_URL, data=json.dumps({
|
||||
'subscription': SUBSCRIPTION_1['subscription'],
|
||||
}))
|
||||
|
||||
config.pop('some device')
|
||||
@asyncio.coroutine
|
||||
def test_callback_view_no_jwt(hass, test_client):
|
||||
"""Test that the notification callback view works without JWT."""
|
||||
client = yield from mock_client(hass, test_client)
|
||||
resp = yield from client.post(PUBLISH_URL, data=json.dumps({
|
||||
'type': 'push',
|
||||
'tag': '3bc28d69-0921-41f1-ac6a-7a627ba0aa72'
|
||||
}))
|
||||
|
||||
assert resp.status == 200, resp.response
|
||||
assert view.registrations == config
|
||||
assert resp.status == 401, resp.response
|
||||
|
||||
hass.async_add_job.assert_called_with(save_json, CONFIG_FILE,
|
||||
config)
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_unregister_device_view_handle_unknown_subscription(
|
||||
self, loop, test_client):
|
||||
"""Test that the HTML unregister view handles unknown subscriptions."""
|
||||
hass = MagicMock()
|
||||
@asyncio.coroutine
|
||||
def test_callback_view_with_jwt(hass, test_client):
|
||||
"""Test that the notification callback view works with JWT."""
|
||||
registrations = {
|
||||
'device': SUBSCRIPTION_1
|
||||
}
|
||||
client = yield from mock_client(hass, test_client, registrations)
|
||||
|
||||
config = {
|
||||
'some device': SUBSCRIPTION_1,
|
||||
'other device': SUBSCRIPTION_2,
|
||||
}
|
||||
with patch('pywebpush.WebPusher') as mock_wp:
|
||||
yield from hass.services.async_call('notify', 'notify', {
|
||||
'message': 'Hello',
|
||||
'target': ['device'],
|
||||
'data': {'icon': 'beer.png'}
|
||||
}, blocking=True)
|
||||
|
||||
m = mock_open(read_data=json.dumps(config))
|
||||
with patch(
|
||||
'homeassistant.util.json.open',
|
||||
m, create=True
|
||||
):
|
||||
hass.config.path.return_value = CONFIG_FILE
|
||||
service = html5.get_service(hass, {})
|
||||
assert len(mock_wp.mock_calls) == 3
|
||||
|
||||
assert service is not None
|
||||
# WebPusher constructor
|
||||
assert mock_wp.mock_calls[0][1][0] == \
|
||||
SUBSCRIPTION_1['subscription']
|
||||
# Third mock_call checks the status_code of the response.
|
||||
assert mock_wp.mock_calls[2][0] == '().send().status_code.__eq__'
|
||||
|
||||
# assert hass.called
|
||||
assert len(hass.mock_calls) == 3
|
||||
# Call to send
|
||||
push_payload = json.loads(mock_wp.mock_calls[1][1][0])
|
||||
|
||||
view = hass.mock_calls[1][1][0]
|
||||
assert view.json_path == hass.config.path.return_value
|
||||
assert view.registrations == config
|
||||
assert push_payload['body'] == 'Hello'
|
||||
assert push_payload['icon'] == 'beer.png'
|
||||
|
||||
hass.loop = loop
|
||||
app = mock_http_component_app(hass)
|
||||
view.register(app.router)
|
||||
client = yield from test_client(app)
|
||||
hass.http.is_banned_ip.return_value = False
|
||||
bearer_token = "Bearer {}".format(push_payload['data']['jwt'])
|
||||
|
||||
resp = yield from client.delete(REGISTER_URL, data=json.dumps({
|
||||
'subscription': SUBSCRIPTION_3['subscription']
|
||||
}))
|
||||
resp = yield from client.post(PUBLISH_URL, json={
|
||||
'type': 'push',
|
||||
}, headers={AUTHORIZATION: bearer_token})
|
||||
|
||||
assert resp.status == 200, resp.response
|
||||
assert view.registrations == config
|
||||
|
||||
hass.async_add_job.assert_not_called()
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_unregistering_device_view_handles_save_error(
|
||||
self, loop, test_client):
|
||||
"""Test that the HTML unregister view handles save errors."""
|
||||
hass = MagicMock()
|
||||
|
||||
config = {
|
||||
'some device': SUBSCRIPTION_1,
|
||||
'other device': SUBSCRIPTION_2,
|
||||
}
|
||||
|
||||
m = mock_open(read_data=json.dumps(config))
|
||||
with patch(
|
||||
'homeassistant.util.json.open',
|
||||
m, create=True
|
||||
):
|
||||
hass.config.path.return_value = CONFIG_FILE
|
||||
service = html5.get_service(hass, {})
|
||||
|
||||
assert service is not None
|
||||
|
||||
# assert hass.called
|
||||
assert len(hass.mock_calls) == 3
|
||||
|
||||
view = hass.mock_calls[1][1][0]
|
||||
assert view.json_path == hass.config.path.return_value
|
||||
assert view.registrations == config
|
||||
|
||||
hass.loop = loop
|
||||
app = mock_http_component_app(hass)
|
||||
view.register(app.router)
|
||||
client = yield from test_client(app)
|
||||
hass.http.is_banned_ip.return_value = False
|
||||
|
||||
hass.async_add_job.side_effect = HomeAssistantError()
|
||||
resp = yield from client.delete(REGISTER_URL, data=json.dumps({
|
||||
'subscription': SUBSCRIPTION_1['subscription'],
|
||||
}))
|
||||
|
||||
assert resp.status == 500, resp.response
|
||||
assert view.registrations == config
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_callback_view_no_jwt(self, loop, test_client):
|
||||
"""Test that the notification callback view works without JWT."""
|
||||
hass = MagicMock()
|
||||
|
||||
m = mock_open()
|
||||
with patch(
|
||||
'homeassistant.util.json.open',
|
||||
m, create=True
|
||||
):
|
||||
hass.config.path.return_value = CONFIG_FILE
|
||||
service = html5.get_service(hass, {})
|
||||
|
||||
assert service is not None
|
||||
|
||||
# assert hass.called
|
||||
assert len(hass.mock_calls) == 3
|
||||
|
||||
view = hass.mock_calls[2][1][0]
|
||||
|
||||
hass.loop = loop
|
||||
app = mock_http_component_app(hass)
|
||||
view.register(app.router)
|
||||
client = yield from test_client(app)
|
||||
hass.http.is_banned_ip.return_value = False
|
||||
|
||||
resp = yield from client.post(PUBLISH_URL, data=json.dumps({
|
||||
'type': 'push',
|
||||
'tag': '3bc28d69-0921-41f1-ac6a-7a627ba0aa72'
|
||||
}))
|
||||
|
||||
assert resp.status == 401, resp.response
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_callback_view_with_jwt(self, loop, test_client):
|
||||
"""Test that the notification callback view works with JWT."""
|
||||
hass = MagicMock()
|
||||
|
||||
data = {
|
||||
'device': SUBSCRIPTION_1
|
||||
}
|
||||
|
||||
m = mock_open(read_data=json.dumps(data))
|
||||
with patch(
|
||||
'homeassistant.util.json.open',
|
||||
m, create=True
|
||||
):
|
||||
hass.config.path.return_value = CONFIG_FILE
|
||||
service = html5.get_service(hass, {'gcm_sender_id': '100'})
|
||||
|
||||
assert service is not None
|
||||
|
||||
# assert hass.called
|
||||
assert len(hass.mock_calls) == 3
|
||||
|
||||
with patch('pywebpush.WebPusher') as mock_wp:
|
||||
service.send_message(
|
||||
'Hello', target=['device'], data={'icon': 'beer.png'})
|
||||
|
||||
assert len(mock_wp.mock_calls) == 3
|
||||
|
||||
# WebPusher constructor
|
||||
assert mock_wp.mock_calls[0][1][0] == \
|
||||
SUBSCRIPTION_1['subscription']
|
||||
# Third mock_call checks the status_code of the response.
|
||||
assert mock_wp.mock_calls[2][0] == '().send().status_code.__eq__'
|
||||
|
||||
# Call to send
|
||||
push_payload = json.loads(mock_wp.mock_calls[1][1][0])
|
||||
|
||||
assert push_payload['body'] == 'Hello'
|
||||
assert push_payload['icon'] == 'beer.png'
|
||||
|
||||
view = hass.mock_calls[2][1][0]
|
||||
view.registrations = data
|
||||
|
||||
bearer_token = "Bearer {}".format(push_payload['data']['jwt'])
|
||||
|
||||
hass.loop = loop
|
||||
app = mock_http_component_app(hass)
|
||||
view.register(app.router)
|
||||
client = yield from test_client(app)
|
||||
hass.http.is_banned_ip.return_value = False
|
||||
|
||||
resp = yield from client.post(PUBLISH_URL, data=json.dumps({
|
||||
'type': 'push',
|
||||
}), headers={AUTHORIZATION: bearer_token})
|
||||
|
||||
assert resp.status == 200
|
||||
body = yield from resp.json()
|
||||
assert body == {"event": "push", "status": "ok"}
|
||||
assert resp.status == 200
|
||||
body = yield from resp.json()
|
||||
assert body == {"event": "push", "status": "ok"}
|
||||
|
@ -10,8 +10,7 @@ import homeassistant.util.dt as dt_util
|
||||
from homeassistant.components import history, recorder
|
||||
|
||||
from tests.common import (
|
||||
init_recorder_component, mock_http_component, mock_state_change_event,
|
||||
get_test_home_assistant)
|
||||
init_recorder_component, mock_state_change_event, get_test_home_assistant)
|
||||
|
||||
|
||||
class TestComponentHistory(unittest.TestCase):
|
||||
@ -38,7 +37,6 @@ class TestComponentHistory(unittest.TestCase):
|
||||
|
||||
def test_setup(self):
|
||||
"""Test setup method of history."""
|
||||
mock_http_component(self.hass)
|
||||
config = history.CONFIG_SCHEMA({
|
||||
# ha.DOMAIN: {},
|
||||
history.DOMAIN: {
|
||||
|
@ -14,7 +14,7 @@ from homeassistant.components import logbook
|
||||
from homeassistant.setup import setup_component
|
||||
|
||||
from tests.common import (
|
||||
mock_http_component, init_recorder_component, get_test_home_assistant)
|
||||
init_recorder_component, get_test_home_assistant)
|
||||
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
@ -29,10 +29,7 @@ class TestComponentLogbook(unittest.TestCase):
|
||||
"""Setup things to be run when tests are started."""
|
||||
self.hass = get_test_home_assistant()
|
||||
init_recorder_component(self.hass) # Force an in memory DB
|
||||
mock_http_component(self.hass)
|
||||
self.hass.config.components |= set(['frontend', 'recorder', 'api'])
|
||||
assert setup_component(self.hass, logbook.DOMAIN,
|
||||
self.EMPTY_CONFIG)
|
||||
assert setup_component(self.hass, logbook.DOMAIN, self.EMPTY_CONFIG)
|
||||
self.hass.start()
|
||||
|
||||
def tearDown(self):
|
||||
|
@ -150,7 +150,6 @@ def test_api_update_fails(hass, test_client):
|
||||
assert resp.status == 404
|
||||
|
||||
beer_id = hass.data['shopping_list'].items[0]['id']
|
||||
client = yield from test_client(hass.http.app)
|
||||
resp = yield from client.post(
|
||||
'/api/shopping_list/item/{}'.format(beer_id), json={
|
||||
'name': 123,
|
||||
|
@ -8,8 +8,9 @@ import pytest
|
||||
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.components import websocket_api as wapi, frontend
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import mock_http_component_app, mock_coro
|
||||
from tests.common import mock_coro
|
||||
|
||||
API_PASSWORD = 'test1234'
|
||||
|
||||
@ -17,10 +18,10 @@ API_PASSWORD = 'test1234'
|
||||
@pytest.fixture
|
||||
def websocket_client(loop, hass, test_client):
|
||||
"""Websocket client fixture connected to websocket server."""
|
||||
websocket_app = mock_http_component_app(hass)
|
||||
wapi.WebsocketAPIView().register(websocket_app.router)
|
||||
assert loop.run_until_complete(
|
||||
async_setup_component(hass, 'websocket_api'))
|
||||
|
||||
client = loop.run_until_complete(test_client(websocket_app))
|
||||
client = loop.run_until_complete(test_client(hass.http.app))
|
||||
ws = loop.run_until_complete(client.ws_connect(wapi.URL))
|
||||
|
||||
auth_ok = loop.run_until_complete(ws.receive_json())
|
||||
@ -35,10 +36,14 @@ def websocket_client(loop, hass, test_client):
|
||||
@pytest.fixture
|
||||
def no_auth_websocket_client(hass, loop, test_client):
|
||||
"""Websocket connection that requires authentication."""
|
||||
websocket_app = mock_http_component_app(hass, API_PASSWORD)
|
||||
wapi.WebsocketAPIView().register(websocket_app.router)
|
||||
assert loop.run_until_complete(
|
||||
async_setup_component(hass, 'websocket_api', {
|
||||
'http': {
|
||||
'api_password': API_PASSWORD
|
||||
}
|
||||
}))
|
||||
|
||||
client = loop.run_until_complete(test_client(websocket_app))
|
||||
client = loop.run_until_complete(test_client(hass.http.app))
|
||||
ws = loop.run_until_complete(client.ws_connect(wapi.URL))
|
||||
|
||||
auth_ok = loop.run_until_complete(ws.receive_json())
|
||||
|
Loading…
x
Reference in New Issue
Block a user