Switch linear search to a dict lookup for ip bans (#74482)

This commit is contained in:
J. Nick Koston 2022-07-07 03:57:44 -05:00 committed by GitHub
parent ae295f1bf5
commit 0c29b68cf8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 170 additions and 58 deletions

View File

@ -26,7 +26,7 @@ from .view import HomeAssistantView
_LOGGER: Final = logging.getLogger(__name__) _LOGGER: Final = logging.getLogger(__name__)
KEY_BANNED_IPS: Final = "ha_banned_ips" KEY_BAN_MANAGER: Final = "ha_banned_ips_manager"
KEY_FAILED_LOGIN_ATTEMPTS: Final = "ha_failed_login_attempts" KEY_FAILED_LOGIN_ATTEMPTS: Final = "ha_failed_login_attempts"
KEY_LOGIN_THRESHOLD: Final = "ha_login_threshold" KEY_LOGIN_THRESHOLD: Final = "ha_login_threshold"
@ -50,9 +50,9 @@ def setup_bans(hass: HomeAssistant, app: Application, login_threshold: int) -> N
async def ban_startup(app: Application) -> None: async def ban_startup(app: Application) -> None:
"""Initialize bans when app starts up.""" """Initialize bans when app starts up."""
app[KEY_BANNED_IPS] = await async_load_ip_bans_config( ban_manager = IpBanManager(hass)
hass, hass.config.path(IP_BANS_FILE) await ban_manager.async_load()
) app[KEY_BAN_MANAGER] = ban_manager
app.on_startup.append(ban_startup) app.on_startup.append(ban_startup)
@ -62,18 +62,17 @@ async def ban_middleware(
request: Request, handler: Callable[[Request], Awaitable[StreamResponse]] request: Request, handler: Callable[[Request], Awaitable[StreamResponse]]
) -> StreamResponse: ) -> StreamResponse:
"""IP Ban middleware.""" """IP Ban middleware."""
if KEY_BANNED_IPS not in request.app: ban_manager: IpBanManager | None = request.app.get(KEY_BAN_MANAGER)
if ban_manager is None:
_LOGGER.error("IP Ban middleware loaded but banned IPs not loaded") _LOGGER.error("IP Ban middleware loaded but banned IPs not loaded")
return await handler(request) return await handler(request)
# Verify if IP is not banned ip_bans_lookup = ban_manager.ip_bans_lookup
ip_address_ = ip_address(request.remote) # type: ignore[arg-type] if ip_bans_lookup:
is_banned = any( # Verify if IP is not banned
ip_ban.ip_address == ip_address_ for ip_ban in request.app[KEY_BANNED_IPS] ip_address_ = ip_address(request.remote) # type: ignore[arg-type]
) if ip_address_ in ip_bans_lookup:
raise HTTPForbidden()
if is_banned:
raise HTTPForbidden()
try: try:
return await handler(request) return await handler(request)
@ -129,7 +128,7 @@ async def process_wrong_login(request: Request) -> None:
) )
# Check if ban middleware is loaded # Check if ban middleware is loaded
if KEY_BANNED_IPS not in request.app or request.app[KEY_LOGIN_THRESHOLD] < 1: if KEY_BAN_MANAGER not in request.app or request.app[KEY_LOGIN_THRESHOLD] < 1:
return return
request.app[KEY_FAILED_LOGIN_ATTEMPTS][remote_addr] += 1 request.app[KEY_FAILED_LOGIN_ATTEMPTS][remote_addr] += 1
@ -146,14 +145,9 @@ async def process_wrong_login(request: Request) -> None:
request.app[KEY_FAILED_LOGIN_ATTEMPTS][remote_addr] request.app[KEY_FAILED_LOGIN_ATTEMPTS][remote_addr]
>= request.app[KEY_LOGIN_THRESHOLD] >= request.app[KEY_LOGIN_THRESHOLD]
): ):
new_ban = IpBan(remote_addr) ban_manager: IpBanManager = request.app[KEY_BAN_MANAGER]
request.app[KEY_BANNED_IPS].append(new_ban)
await hass.async_add_executor_job(
update_ip_bans_config, hass.config.path(IP_BANS_FILE), new_ban
)
_LOGGER.warning("Banned IP %s for too many login attempts", remote_addr) _LOGGER.warning("Banned IP %s for too many login attempts", remote_addr)
await ban_manager.async_add_ban(remote_addr)
persistent_notification.async_create( persistent_notification.async_create(
hass, hass,
@ -173,7 +167,7 @@ async def process_success_login(request: Request) -> None:
remote_addr = ip_address(request.remote) # type: ignore[arg-type] remote_addr = ip_address(request.remote) # type: ignore[arg-type]
# Check if ban middleware is loaded # Check if ban middleware is loaded
if KEY_BANNED_IPS not in request.app or request.app[KEY_LOGIN_THRESHOLD] < 1: if KEY_BAN_MANAGER not in request.app or request.app[KEY_LOGIN_THRESHOLD] < 1:
return return
if ( if (
@ -199,32 +193,49 @@ class IpBan:
self.banned_at = banned_at or dt_util.utcnow() self.banned_at = banned_at or dt_util.utcnow()
async def async_load_ip_bans_config(hass: HomeAssistant, path: str) -> list[IpBan]: class IpBanManager:
"""Load list of banned IPs from config file.""" """Manage IP bans."""
ip_list: list[IpBan] = []
try: def __init__(self, hass: HomeAssistant) -> None:
list_ = await hass.async_add_executor_job(load_yaml_config_file, path) """Init the ban manager."""
except FileNotFoundError: self.hass = hass
return ip_list self.path = hass.config.path(IP_BANS_FILE)
except HomeAssistantError as err: self.ip_bans_lookup: dict[IPv4Address | IPv6Address, IpBan] = {}
_LOGGER.error("Unable to load %s: %s", path, str(err))
return ip_list
for ip_ban, ip_info in list_.items(): async def async_load(self) -> None:
"""Load the existing IP bans."""
try: try:
ip_info = SCHEMA_IP_BAN_ENTRY(ip_info) list_ = await self.hass.async_add_executor_job(
ip_list.append(IpBan(ip_ban, ip_info["banned_at"])) load_yaml_config_file, self.path
except vol.Invalid as err: )
_LOGGER.error("Failed to load IP ban %s: %s", ip_info, err) except FileNotFoundError:
continue return
except HomeAssistantError as err:
_LOGGER.error("Unable to load %s: %s", self.path, str(err))
return
return ip_list ip_bans_lookup: dict[IPv4Address | IPv6Address, IpBan] = {}
for ip_ban, ip_info in list_.items():
try:
ip_info = SCHEMA_IP_BAN_ENTRY(ip_info)
ban = IpBan(ip_ban, ip_info["banned_at"])
ip_bans_lookup[ban.ip_address] = ban
except vol.Invalid as err:
_LOGGER.error("Failed to load IP ban %s: %s", ip_info, err)
continue
self.ip_bans_lookup = ip_bans_lookup
def update_ip_bans_config(path: str, ip_ban: IpBan) -> None: def _add_ban(self, ip_ban: IpBan) -> None:
"""Update config file with new banned IP address.""" """Update config file with new banned IP address."""
with open(path, "a", encoding="utf8") as out: with open(self.path, "a", encoding="utf8") as out:
ip_ = {str(ip_ban.ip_address): {ATTR_BANNED_AT: ip_ban.banned_at.isoformat()}} ip_ = {
out.write("\n") str(ip_ban.ip_address): {ATTR_BANNED_AT: ip_ban.banned_at.isoformat()}
out.write(yaml.dump(ip_)) }
# Write in a single write call to avoid interleaved writes
out.write("\n" + yaml.dump(ip_))
async def async_add_ban(self, remote_addr: IPv4Address | IPv6Address) -> None:
"""Add a new IP address to the banned list."""
new_ban = self.ip_bans_lookup[remote_addr] = IpBan(remote_addr)
await self.hass.async_add_executor_job(self._add_ban, new_ban)

View File

@ -19,8 +19,7 @@ def patch_zeroconf_multiple_catcher():
def prevent_io(): def prevent_io():
"""Fixture to prevent certain I/O from happening.""" """Fixture to prevent certain I/O from happening."""
with patch( with patch(
"homeassistant.components.http.ban.async_load_ip_bans_config", "homeassistant.components.http.ban.load_yaml_config_file",
return_value=[],
): ):
yield yield

View File

@ -15,12 +15,13 @@ import homeassistant.components.http as http
from homeassistant.components.http import KEY_AUTHENTICATED from homeassistant.components.http import KEY_AUTHENTICATED
from homeassistant.components.http.ban import ( from homeassistant.components.http.ban import (
IP_BANS_FILE, IP_BANS_FILE,
KEY_BANNED_IPS, KEY_BAN_MANAGER,
KEY_FAILED_LOGIN_ATTEMPTS, KEY_FAILED_LOGIN_ATTEMPTS,
IpBan, IpBanManager,
setup_bans, setup_bans,
) )
from homeassistant.components.http.view import request_handler_factory from homeassistant.components.http.view import request_handler_factory
from homeassistant.exceptions import HomeAssistantError
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from . import mock_real_ip from . import mock_real_ip
@ -58,8 +59,10 @@ async def test_access_from_banned_ip(hass, aiohttp_client):
set_real_ip = mock_real_ip(app) set_real_ip = mock_real_ip(app)
with patch( with patch(
"homeassistant.components.http.ban.async_load_ip_bans_config", "homeassistant.components.http.ban.load_yaml_config_file",
return_value=[IpBan(banned_ip) for banned_ip in BANNED_IPS], return_value={
banned_ip: {"banned_at": "2016-11-16T19:20:03"} for banned_ip in BANNED_IPS
},
): ):
client = await aiohttp_client(app) client = await aiohttp_client(app)
@ -69,6 +72,99 @@ async def test_access_from_banned_ip(hass, aiohttp_client):
assert resp.status == HTTPStatus.FORBIDDEN assert resp.status == HTTPStatus.FORBIDDEN
async def test_access_from_banned_ip_with_partially_broken_yaml_file(
hass, aiohttp_client, caplog
):
"""Test accessing to server from banned IP. Both trusted and not.
We inject some garbage into the yaml file to make sure it can
still load the bans.
"""
app = web.Application()
app["hass"] = hass
setup_bans(hass, app, 5)
set_real_ip = mock_real_ip(app)
data = {banned_ip: {"banned_at": "2016-11-16T19:20:03"} for banned_ip in BANNED_IPS}
data["5.3.3.3"] = {"banned_at": "garbage"}
with patch(
"homeassistant.components.http.ban.load_yaml_config_file",
return_value=data,
):
client = await aiohttp_client(app)
for remote_addr in BANNED_IPS:
set_real_ip(remote_addr)
resp = await client.get("/")
assert resp.status == HTTPStatus.FORBIDDEN
# Ensure garbage data is ignored
set_real_ip("5.3.3.3")
resp = await client.get("/")
assert resp.status == HTTPStatus.NOT_FOUND
assert "Failed to load IP ban" in caplog.text
async def test_no_ip_bans_file(hass, aiohttp_client):
"""Test no ip bans file."""
app = web.Application()
app["hass"] = hass
setup_bans(hass, app, 5)
set_real_ip = mock_real_ip(app)
with patch(
"homeassistant.components.http.ban.load_yaml_config_file",
side_effect=FileNotFoundError,
):
client = await aiohttp_client(app)
set_real_ip("4.3.2.1")
resp = await client.get("/")
assert resp.status == HTTPStatus.NOT_FOUND
async def test_failure_loading_ip_bans_file(hass, aiohttp_client):
"""Test failure loading ip bans file."""
app = web.Application()
app["hass"] = hass
setup_bans(hass, app, 5)
set_real_ip = mock_real_ip(app)
with patch(
"homeassistant.components.http.ban.load_yaml_config_file",
side_effect=HomeAssistantError,
):
client = await aiohttp_client(app)
set_real_ip("4.3.2.1")
resp = await client.get("/")
assert resp.status == HTTPStatus.NOT_FOUND
async def test_ip_ban_manager_never_started(hass, aiohttp_client, caplog):
"""Test we handle the ip ban manager not being started."""
app = web.Application()
app["hass"] = hass
setup_bans(hass, app, 5)
set_real_ip = mock_real_ip(app)
with patch(
"homeassistant.components.http.ban.load_yaml_config_file",
side_effect=FileNotFoundError,
):
client = await aiohttp_client(app)
# Mock the manager never being started
del app[KEY_BAN_MANAGER]
set_real_ip("4.3.2.1")
resp = await client.get("/")
assert resp.status == HTTPStatus.NOT_FOUND
assert "IP Ban middleware loaded but banned IPs not loaded" in caplog.text
@pytest.mark.parametrize( @pytest.mark.parametrize(
"remote_addr, bans, status", "remote_addr, bans, status",
list( list(
@ -95,10 +191,13 @@ async def test_access_from_supervisor_ip(
mock_real_ip(app)(remote_addr) mock_real_ip(app)(remote_addr)
with patch( with patch(
"homeassistant.components.http.ban.async_load_ip_bans_config", return_value=[] "homeassistant.components.http.ban.load_yaml_config_file",
return_value={},
): ):
client = await aiohttp_client(app) client = await aiohttp_client(app)
manager: IpBanManager = app[KEY_BAN_MANAGER]
assert await async_setup_component(hass, "hassio", {"hassio": {}}) assert await async_setup_component(hass, "hassio", {"hassio": {}})
m_open = mock_open() m_open = mock_open()
@ -108,13 +207,13 @@ async def test_access_from_supervisor_ip(
): ):
resp = await client.get("/") resp = await client.get("/")
assert resp.status == HTTPStatus.UNAUTHORIZED assert resp.status == HTTPStatus.UNAUTHORIZED
assert len(app[KEY_BANNED_IPS]) == bans assert len(manager.ip_bans_lookup) == bans
assert m_open.call_count == bans assert m_open.call_count == bans
# second request should be forbidden if banned # second request should be forbidden if banned
resp = await client.get("/") resp = await client.get("/")
assert resp.status == status assert resp.status == status
assert len(app[KEY_BANNED_IPS]) == bans assert len(manager.ip_bans_lookup) == bans
async def test_ban_middleware_not_loaded_by_config(hass): async def test_ban_middleware_not_loaded_by_config(hass):
@ -149,22 +248,25 @@ async def test_ip_bans_file_creation(hass, aiohttp_client):
mock_real_ip(app)("200.201.202.204") mock_real_ip(app)("200.201.202.204")
with patch( with patch(
"homeassistant.components.http.ban.async_load_ip_bans_config", "homeassistant.components.http.ban.load_yaml_config_file",
return_value=[IpBan(banned_ip) for banned_ip in BANNED_IPS], return_value={
banned_ip: {"banned_at": "2016-11-16T19:20:03"} for banned_ip in BANNED_IPS
},
): ):
client = await aiohttp_client(app) client = await aiohttp_client(app)
manager: IpBanManager = app[KEY_BAN_MANAGER]
m_open = mock_open() m_open = mock_open()
with patch("homeassistant.components.http.ban.open", m_open, create=True): with patch("homeassistant.components.http.ban.open", m_open, create=True):
resp = await client.get("/") resp = await client.get("/")
assert resp.status == HTTPStatus.UNAUTHORIZED assert resp.status == HTTPStatus.UNAUTHORIZED
assert len(app[KEY_BANNED_IPS]) == len(BANNED_IPS) assert len(manager.ip_bans_lookup) == len(BANNED_IPS)
assert m_open.call_count == 0 assert m_open.call_count == 0
resp = await client.get("/") resp = await client.get("/")
assert resp.status == HTTPStatus.UNAUTHORIZED assert resp.status == HTTPStatus.UNAUTHORIZED
assert len(app[KEY_BANNED_IPS]) == len(BANNED_IPS) + 1 assert len(manager.ip_bans_lookup) == len(BANNED_IPS) + 1
m_open.assert_called_once_with( m_open.assert_called_once_with(
hass.config.path(IP_BANS_FILE), "a", encoding="utf8" hass.config.path(IP_BANS_FILE), "a", encoding="utf8"
) )