Advanced Ip filtering (#4424)

* Added IP Bans configuration

* Fixing warnings

* Added ban enabled option and unit tests

* Fixed py34 tox

* http: requested changes fix

* Requested changes fix
This commit is contained in:
Vlad Korniev 2016-11-24 21:52:10 -08:00 committed by Paulus Schoutsen
parent 95b439fbd5
commit 2a7bc0e55c
6 changed files with 225 additions and 18 deletions

View File

@ -75,9 +75,12 @@ def setup(hass, yaml_config):
api_password=None, api_password=None,
ssl_certificate=None, ssl_certificate=None,
ssl_key=None, ssl_key=None,
cors_origins=[], cors_origins=None,
use_x_forwarded_for=False, use_x_forwarded_for=False,
trusted_networks=[] trusted_networks=None,
ip_bans=None,
login_threshold=0,
is_ban_enabled=False
) )
server.register_view(DescriptionXmlView(hass, config)) server.register_view(DescriptionXmlView(hass, config))

View File

@ -5,32 +5,36 @@ For more details about this component, please refer to the documentation at
https://home-assistant.io/components/http/ https://home-assistant.io/components/http/
""" """
import asyncio import asyncio
import hmac
import json import json
import logging import logging
import mimetypes import mimetypes
import os
from pathlib import Path
import re
import ssl import ssl
from datetime import datetime
from ipaddress import ip_address, ip_network from ipaddress import ip_address, ip_network
from pathlib import Path
import hmac
import os
import re
import voluptuous as vol import voluptuous as vol
from aiohttp import web, hdrs from aiohttp import web, hdrs
from aiohttp.file_sender import FileSender from aiohttp.file_sender import FileSender
from aiohttp.web_exceptions import ( from aiohttp.web_exceptions import (
HTTPUnauthorized, HTTPMovedPermanently, HTTPNotModified) HTTPUnauthorized, HTTPMovedPermanently, HTTPNotModified, HTTPForbidden)
from aiohttp.web_urldispatcher import StaticResource from aiohttp.web_urldispatcher import StaticResource
from homeassistant.core import is_callback import homeassistant.helpers.config_validation as cv
import homeassistant.remote as rem import homeassistant.remote as rem
from homeassistant import util from homeassistant import util
from homeassistant.components import persistent_notification
from homeassistant.config import load_yaml_config_file
from homeassistant.const import ( from homeassistant.const import (
SERVER_PORT, HTTP_HEADER_HA_AUTH, # HTTP_HEADER_CACHE_CONTROL, SERVER_PORT, HTTP_HEADER_HA_AUTH, # HTTP_HEADER_CACHE_CONTROL,
CONTENT_TYPE_JSON, ALLOWED_CORS_HEADERS, EVENT_HOMEASSISTANT_STOP, CONTENT_TYPE_JSON, ALLOWED_CORS_HEADERS, EVENT_HOMEASSISTANT_STOP,
EVENT_HOMEASSISTANT_START, HTTP_HEADER_X_FORWARDED_FOR) EVENT_HOMEASSISTANT_START, HTTP_HEADER_X_FORWARDED_FOR)
import homeassistant.helpers.config_validation as cv from homeassistant.core import is_callback
from homeassistant.components import persistent_notification from homeassistant.exceptions import HomeAssistantError
from homeassistant.util.yaml import dump
DOMAIN = 'http' DOMAIN = 'http'
REQUIREMENTS = ('aiohttp_cors==0.5.0',) REQUIREMENTS = ('aiohttp_cors==0.5.0',)
@ -44,9 +48,16 @@ CONF_SSL_KEY = 'ssl_key'
CONF_CORS_ORIGINS = 'cors_allowed_origins' CONF_CORS_ORIGINS = 'cors_allowed_origins'
CONF_USE_X_FORWARDED_FOR = 'use_x_forwarded_for' CONF_USE_X_FORWARDED_FOR = 'use_x_forwarded_for'
CONF_TRUSTED_NETWORKS = 'trusted_networks' CONF_TRUSTED_NETWORKS = 'trusted_networks'
CONF_LOGIN_ATTEMPTS_THRESHOLD = 'login_attempts_threshold'
CONF_IP_BAN_ENABLED = 'ip_ban_enabled'
DATA_API_PASSWORD = 'api_password' DATA_API_PASSWORD = 'api_password'
NOTIFICATION_ID_LOGIN = 'http-login' NOTIFICATION_ID_LOGIN = 'http-login'
NOTIFICATION_ID_BAN = 'ip-ban'
IP_BANS = 'ip_bans.yaml'
ATTR_BANNED_AT = "banned_at"
# TLS configuation follows the best-practice guidelines specified here: # TLS configuation follows the best-practice guidelines specified here:
# https://wiki.mozilla.org/Security/Server_Side_TLS # https://wiki.mozilla.org/Security/Server_Side_TLS
@ -85,7 +96,9 @@ CONFIG_SCHEMA = vol.Schema({
vol.Optional(CONF_CORS_ORIGINS): vol.All(cv.ensure_list, [cv.string]), vol.Optional(CONF_CORS_ORIGINS): vol.All(cv.ensure_list, [cv.string]),
vol.Optional(CONF_USE_X_FORWARDED_FOR, default=False): cv.boolean, vol.Optional(CONF_USE_X_FORWARDED_FOR, default=False): cv.boolean,
vol.Optional(CONF_TRUSTED_NETWORKS): vol.Optional(CONF_TRUSTED_NETWORKS):
vol.All(cv.ensure_list, [ip_network]) vol.All(cv.ensure_list, [ip_network]),
vol.Optional(CONF_LOGIN_ATTEMPTS_THRESHOLD): cv.positive_int,
vol.Optional(CONF_IP_BAN_ENABLED): cv.boolean
}), }),
}, extra=vol.ALLOW_EXTRA) }, extra=vol.ALLOW_EXTRA)
@ -131,6 +144,9 @@ def setup(hass, config):
trusted_networks = [ trusted_networks = [
ip_network(trusted_network) ip_network(trusted_network)
for trusted_network in conf.get(CONF_TRUSTED_NETWORKS, [])] for trusted_network in conf.get(CONF_TRUSTED_NETWORKS, [])]
is_ban_enabled = bool(conf.get(CONF_IP_BAN_ENABLED, False))
login_threshold = int(conf.get(CONF_LOGIN_ATTEMPTS_THRESHOLD, -1))
ip_bans = load_ip_bans_config(hass.config.path(IP_BANS))
server = HomeAssistantWSGI( server = HomeAssistantWSGI(
hass, hass,
@ -142,7 +158,10 @@ def setup(hass, config):
ssl_key=ssl_key, ssl_key=ssl_key,
cors_origins=cors_origins, cors_origins=cors_origins,
use_x_forwarded_for=use_x_forwarded_for, use_x_forwarded_for=use_x_forwarded_for,
trusted_networks=trusted_networks trusted_networks=trusted_networks,
ip_bans=ip_bans,
login_threshold=login_threshold,
is_ban_enabled=is_ban_enabled
) )
@asyncio.coroutine @asyncio.coroutine
@ -254,7 +273,8 @@ class HomeAssistantWSGI(object):
def __init__(self, hass, development, api_password, ssl_certificate, def __init__(self, hass, development, api_password, ssl_certificate,
ssl_key, server_host, server_port, cors_origins, ssl_key, server_host, server_port, cors_origins,
use_x_forwarded_for, trusted_networks): use_x_forwarded_for, trusted_networks,
ip_bans, login_threshold, is_ban_enabled):
"""Initialize the WSGI Home Assistant server.""" """Initialize the WSGI Home Assistant server."""
import aiohttp_cors import aiohttp_cors
@ -268,10 +288,15 @@ class HomeAssistantWSGI(object):
self.server_host = server_host self.server_host = server_host
self.server_port = server_port self.server_port = server_port
self.use_x_forwarded_for = use_x_forwarded_for self.use_x_forwarded_for = use_x_forwarded_for
self.trusted_networks = trusted_networks self.trusted_networks = trusted_networks \
if trusted_networks is not None else []
self.event_forwarder = None self.event_forwarder = None
self._handler = None self._handler = None
self.server = None self.server = None
self.login_threshold = login_threshold
self.ip_bans = ip_bans if ip_bans is not None else []
self.failed_login_attempts = {}
self.is_ban_enabled = is_ban_enabled
if cors_origins: if cors_origins:
self.cors = aiohttp_cors.setup(self.app, defaults={ self.cors = aiohttp_cors.setup(self.app, defaults={
@ -385,6 +410,39 @@ class HomeAssistantWSGI(object):
return any(ip_address(remote_addr) in trusted_network return any(ip_address(remote_addr) in trusted_network
for trusted_network in self.hass.http.trusted_networks) for trusted_network in self.hass.http.trusted_networks)
def wrong_login_attempt(self, remote_addr):
"""Registering wrong login attempt."""
if not self.is_ban_enabled or self.login_threshold < 1:
return
if remote_addr in self.failed_login_attempts:
self.failed_login_attempts[remote_addr] += 1
else:
self.failed_login_attempts[remote_addr] = 1
if self.failed_login_attempts[remote_addr] > self.login_threshold:
new_ban = IpBan(remote_addr)
self.ip_bans.append(new_ban)
update_ip_bans_config(self.hass.config.path(IP_BANS), new_ban)
_LOGGER.warning('Banned IP %s for too many login attempts',
remote_addr)
persistent_notification.async_create(
self.hass,
'Too many login attempts from {}'.format(remote_addr),
'Banning IP address', NOTIFICATION_ID_BAN)
def is_banned_ip(self, remote_addr):
"""Check if IP address is in a ban list."""
if not self.is_ban_enabled:
return False
ip_address_ = ip_address(remote_addr)
for ip_ban in self.ip_bans:
if ip_ban.ip_address == ip_address_:
return True
return False
class HomeAssistantView(object): class HomeAssistantView(object):
"""Base view for all views.""" """Base view for all views."""
@ -465,6 +523,9 @@ def request_handler_factory(view, handler):
remote_addr = view.hass.http.get_real_ip(request) remote_addr = view.hass.http.get_real_ip(request)
if view.hass.http.is_banned_ip(remote_addr):
raise HTTPForbidden()
# Auth code verbose on purpose # Auth code verbose on purpose
authenticated = False authenticated = False
@ -484,6 +545,7 @@ def request_handler_factory(view, handler):
authenticated = True authenticated = True
if view.requires_auth and not authenticated: if view.requires_auth and not authenticated:
view.hass.http.wrong_login_attempt(remote_addr)
_LOGGER.warning('Login attempt or request with an invalid ' _LOGGER.warning('Login attempt or request with an invalid '
'password from %s', remote_addr) 'password from %s', remote_addr)
persistent_notification.async_create( persistent_notification.async_create(
@ -525,3 +587,55 @@ def request_handler_factory(view, handler):
return web.Response(body=result, status=status_code) return web.Response(body=result, status=status_code)
return handle return handle
class IpBan(object):
"""Represents banned IP address."""
def __init__(self, ip_ban: str, banned_at: datetime=None) -> None:
"""Initializing Ip Ban object."""
self.ip_address = ip_address(ip_ban)
self.banned_at = banned_at
if self.banned_at is None:
self.banned_at = datetime.utcnow()
def load_ip_bans_config(path: str):
"""Loading list of banned IPs from config file."""
ip_list = []
ip_schema = vol.Schema({
vol.Optional('banned_at'): vol.Any(None, cv.datetime)
})
try:
try:
list_ = load_yaml_config_file(path)
except HomeAssistantError as err:
_LOGGER.error('Unable to load %s: %s', path, str(err))
return []
for ip_ban, ip_info in list_.items():
try:
ip_info = ip_schema(ip_info)
ip_info['ip_ban'] = ip_address(ip_ban)
ip_list.append(IpBan(**ip_info))
except vol.Invalid:
_LOGGER.exception('Failed to load IP ban')
continue
except(HomeAssistantError, FileNotFoundError):
# No need to report error, file absence means
# that no bans were applied.
return []
return ip_list
def update_ip_bans_config(path: str, ip_ban: IpBan):
"""Update config file with new banned IP address."""
with open(path, 'a') as out:
ip_ = {str(ip_ban.ip_address): {
ATTR_BANNED_AT: ip_ban.banned_at.strftime("%Y-%m-%dT%H:%M:%S")
}}
out.write('\n')
out.write(dump(ip_))

View File

@ -1,6 +1,6 @@
"""Helpers for config validation using voluptuous.""" """Helpers for config validation using voluptuous."""
from collections import OrderedDict from collections import OrderedDict
from datetime import timedelta from datetime import timedelta, datetime as datetime_sys
import os import os
import re import re
from urllib.parse import urlparse from urllib.parse import urlparse
@ -297,6 +297,22 @@ def time(value):
return time_val return time_val
def datetime(value):
"""Validate datetime."""
if isinstance(value, datetime_sys):
return value
try:
date_val = dt_util.parse_datetime(value)
except TypeError:
date_val = None
if date_val is None:
raise vol.Invalid('Invalid datetime specified: {}'.format(value))
return date_val
def time_zone(value): def time_zone(value):
"""Validate timezone.""" """Validate timezone."""
if dt_util.get_time_zone(value) is not None: if dt_util.get_time_zone(value) is not None:

View File

@ -124,6 +124,7 @@ class TestHtml5Notify(object):
app = web.Application(loop=loop) app = web.Application(loop=loop)
view.register(app.router) view.register(app.router)
client = yield from test_client(app) client = yield from test_client(app)
hass.http.is_banned_ip.return_value = False
resp = yield from client.post(REGISTER_URL, resp = yield from client.post(REGISTER_URL,
data=json.dumps(SUBSCRIPTION_1)) data=json.dumps(SUBSCRIPTION_1))
@ -155,6 +156,7 @@ class TestHtml5Notify(object):
app = web.Application(loop=loop) app = web.Application(loop=loop)
view.register(app.router) view.register(app.router)
client = yield from test_client(app) client = yield from test_client(app)
hass.http.is_banned_ip.return_value = False
resp = yield from client.post(REGISTER_URL, data=json.dumps({ resp = yield from client.post(REGISTER_URL, data=json.dumps({
'browser': 'invalid browser', 'browser': 'invalid browser',
@ -209,6 +211,7 @@ class TestHtml5Notify(object):
app = web.Application(loop=loop) app = web.Application(loop=loop)
view.register(app.router) view.register(app.router)
client = yield from test_client(app) client = yield from test_client(app)
hass.http.is_banned_ip.return_value = False
resp = yield from client.delete(REGISTER_URL, data=json.dumps({ resp = yield from client.delete(REGISTER_URL, data=json.dumps({
'subscription': SUBSCRIPTION_1['subscription'], 'subscription': SUBSCRIPTION_1['subscription'],
@ -253,6 +256,7 @@ class TestHtml5Notify(object):
app = web.Application(loop=loop) app = web.Application(loop=loop)
view.register(app.router) view.register(app.router)
client = yield from test_client(app) client = yield from test_client(app)
hass.http.is_banned_ip.return_value = False
resp = yield from client.delete(REGISTER_URL, data=json.dumps({ resp = yield from client.delete(REGISTER_URL, data=json.dumps({
'subscription': SUBSCRIPTION_3['subscription'] 'subscription': SUBSCRIPTION_3['subscription']
@ -295,6 +299,7 @@ class TestHtml5Notify(object):
app = web.Application(loop=loop) app = web.Application(loop=loop)
view.register(app.router) view.register(app.router)
client = yield from test_client(app) client = yield from test_client(app)
hass.http.is_banned_ip.return_value = False
with patch('homeassistant.components.notify.html5._save_config', with patch('homeassistant.components.notify.html5._save_config',
return_value=False): return_value=False):
@ -329,6 +334,7 @@ class TestHtml5Notify(object):
app = web.Application(loop=loop) app = web.Application(loop=loop)
view.register(app.router) view.register(app.router)
client = yield from test_client(app) client = yield from test_client(app)
hass.http.is_banned_ip.return_value = False
resp = yield from client.post(PUBLISH_URL, data=json.dumps({ resp = yield from client.post(PUBLISH_URL, data=json.dumps({
'type': 'push', 'type': 'push',
@ -384,6 +390,7 @@ class TestHtml5Notify(object):
app = web.Application(loop=loop) app = web.Application(loop=loop)
view.register(app.router) view.register(app.router)
client = yield from test_client(app) client = yield from test_client(app)
hass.http.is_banned_ip.return_value = False
resp = yield from client.post(PUBLISH_URL, data=json.dumps({ resp = yield from client.post(PUBLISH_URL, data=json.dumps({
'type': 'push', 'type': 'push',

View File

@ -2,7 +2,7 @@
# pylint: disable=protected-access # pylint: disable=protected-access
import logging import logging
from ipaddress import ip_network from ipaddress import ip_network
from unittest.mock import patch from unittest.mock import patch, mock_open
import requests import requests
@ -25,7 +25,7 @@ TRUSTED_NETWORKS = ['192.0.2.0/24', '2001:DB8:ABCD::/48', '100.64.0.1',
TRUSTED_ADDRESSES = ['100.64.0.1', '192.0.2.100', 'FD01:DB8::1', TRUSTED_ADDRESSES = ['100.64.0.1', '192.0.2.100', 'FD01:DB8::1',
'2001:DB8:ABCD::1'] '2001:DB8:ABCD::1']
UNTRUSTED_ADDRESSES = ['198.51.100.1', '2001:DB8:FA1::1', '127.0.0.1', '::1'] UNTRUSTED_ADDRESSES = ['198.51.100.1', '2001:DB8:FA1::1', '127.0.0.1', '::1']
BANNED_IPS = ['200.201.202.203', '100.64.0.1']
CORS_ORIGINS = [HTTP_BASE_URL, HTTP_BASE] CORS_ORIGINS = [HTTP_BASE_URL, HTTP_BASE]
@ -63,6 +63,9 @@ def setUpModule():
ip_network(trusted_network) ip_network(trusted_network)
for trusted_network in TRUSTED_NETWORKS] for trusted_network in TRUSTED_NETWORKS]
hass.http.ip_bans = [http.IpBan(banned_ip)
for banned_ip in BANNED_IPS]
hass.start() hass.start()
@ -227,3 +230,56 @@ class TestHttp:
assert req.headers.get(allow_origin) == HTTP_BASE_URL assert req.headers.get(allow_origin) == HTTP_BASE_URL
assert req.headers.get(allow_headers) == \ assert req.headers.get(allow_headers) == \
const.HTTP_HEADER_HA_AUTH.upper() const.HTTP_HEADER_HA_AUTH.upper()
def test_access_from_banned_ip(self):
"""Test accessing to server from banned IP. Both trusted and not."""
hass.http.is_ban_enabled = True
for remote_addr in BANNED_IPS:
with patch('homeassistant.components.http.'
'HomeAssistantWSGI.get_real_ip',
return_value=remote_addr):
req = requests.get(
_url(const.URL_API))
assert req.status_code == 403
def test_access_from_banned_ip_when_ban_is_off(self):
"""Test accessing to server from banned IP when feature is off"""
hass.http.is_ban_enabled = False
for remote_addr in BANNED_IPS:
with patch('homeassistant.components.http.'
'HomeAssistantWSGI.get_real_ip',
return_value=remote_addr):
req = requests.get(
_url(const.URL_API),
headers={const.HTTP_HEADER_HA_AUTH: API_PASSWORD})
assert req.status_code == 200
def test_ip_bans_file_creation(self):
"""Testing if banned IP file created"""
hass.http.is_ban_enabled = True
hass.http.login_threshold = 1
m = mock_open()
def call_server():
with patch('homeassistant.components.http.'
'HomeAssistantWSGI.get_real_ip',
return_value="200.201.202.204"):
return requests.get(
_url(const.URL_API),
headers={const.HTTP_HEADER_HA_AUTH: 'Wrong password'})
with patch('homeassistant.components.http.open', m, create=True):
req = call_server()
assert req.status_code == 401
assert len(hass.http.ip_bans) == len(BANNED_IPS)
assert m.call_count == 0
req = call_server()
assert req.status_code == 401
assert len(hass.http.ip_bans) == len(BANNED_IPS) + 1
m.assert_called_once_with(hass.config.path(http.IP_BANS), 'a')
req = call_server()
assert req.status_code == 403
assert m.call_count == 1

View File

@ -1,6 +1,6 @@
"""Test config validators.""" """Test config validators."""
from collections import OrderedDict from collections import OrderedDict
from datetime import timedelta from datetime import timedelta, datetime, date
import enum import enum
import os import os
from socket import _GLOBAL_DEFAULT_TIMEOUT from socket import _GLOBAL_DEFAULT_TIMEOUT
@ -358,6 +358,17 @@ def test_time_zone():
schema('UTC') schema('UTC')
def test_datetime():
"""Test date time validation."""
schema = vol.Schema(cv.datetime)
for value in [date.today(), 'Wrong DateTime', '2016-11-23']:
with pytest.raises(vol.MultipleInvalid):
schema(value)
schema(datetime.now())
schema('2016-11-23T18:59:08')
def test_key_dependency(): def test_key_dependency():
"""Test key_dependency validator.""" """Test key_dependency validator."""
schema = vol.Schema(cv.key_dependency('beer', 'soda')) schema = vol.Schema(cv.key_dependency('beer', 'soda'))