Use aiohttp.AppKey for http ban keys (#112657)

This commit is contained in:
Marc Mueller 2024-03-08 11:13:24 +01:00 committed by GitHub
parent 7dcf275966
commit eb8f8e1ae4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 21 additions and 17 deletions

View File

@ -11,7 +11,14 @@ import logging
from socket import gethostbyaddr, herror from socket import gethostbyaddr, herror
from typing import Any, Concatenate, Final, ParamSpec, TypeVar from typing import Any, Concatenate, Final, ParamSpec, TypeVar
from aiohttp.web import Application, Request, Response, StreamResponse, middleware from aiohttp.web import (
AppKey,
Application,
Request,
Response,
StreamResponse,
middleware,
)
from aiohttp.web_exceptions import HTTPForbidden, HTTPUnauthorized from aiohttp.web_exceptions import HTTPForbidden, HTTPUnauthorized
import voluptuous as vol import voluptuous as vol
@ -29,9 +36,11 @@ _P = ParamSpec("_P")
_LOGGER: Final = logging.getLogger(__name__) _LOGGER: Final = logging.getLogger(__name__)
KEY_BAN_MANAGER: Final = "ha_banned_ips_manager" KEY_BAN_MANAGER = AppKey["IpBanManager"]("ha_banned_ips_manager")
KEY_FAILED_LOGIN_ATTEMPTS: Final = "ha_failed_login_attempts" KEY_FAILED_LOGIN_ATTEMPTS = AppKey[defaultdict[IPv4Address | IPv6Address, int]](
KEY_LOGIN_THRESHOLD: Final = "ha_login_threshold" "ha_failed_login_attempts"
)
KEY_LOGIN_THRESHOLD = AppKey[int]("ban_manager.ip_bans_lookup")
NOTIFICATION_ID_BAN: Final = "ip-ban" NOTIFICATION_ID_BAN: Final = "ip-ban"
NOTIFICATION_ID_LOGIN: Final = "http-login" NOTIFICATION_ID_LOGIN: Final = "http-login"
@ -48,7 +57,7 @@ SCHEMA_IP_BAN_ENTRY: Final = vol.Schema(
def setup_bans(hass: HomeAssistant, app: Application, login_threshold: int) -> None: def setup_bans(hass: HomeAssistant, app: Application, login_threshold: int) -> None:
"""Create IP Ban middleware for the app.""" """Create IP Ban middleware for the app."""
app.middlewares.append(ban_middleware) app.middlewares.append(ban_middleware)
app[KEY_FAILED_LOGIN_ATTEMPTS] = defaultdict(int) app[KEY_FAILED_LOGIN_ATTEMPTS] = defaultdict[IPv4Address | IPv6Address, int](int)
app[KEY_LOGIN_THRESHOLD] = login_threshold app[KEY_LOGIN_THRESHOLD] = login_threshold
app[KEY_BAN_MANAGER] = IpBanManager(hass) app[KEY_BAN_MANAGER] = IpBanManager(hass)
@ -64,13 +73,11 @@ async def ban_middleware(
request: Request, handler: Callable[[Request], Awaitable[StreamResponse]] request: Request, handler: Callable[[Request], Awaitable[StreamResponse]]
) -> StreamResponse: ) -> StreamResponse:
"""IP Ban middleware.""" """IP Ban middleware."""
ban_manager: IpBanManager | None = request.app.get(KEY_BAN_MANAGER) if (ban_manager := request.app.get(KEY_BAN_MANAGER)) is None:
if ban_manager is None:
_LOGGER.error("IP Ban middleware loaded but banned IPs not loaded") _LOGGER.error("IP Ban middleware loaded but banned IPs not loaded")
return await handler(request) return await handler(request)
ip_bans_lookup = ban_manager.ip_bans_lookup if ip_bans_lookup := ban_manager.ip_bans_lookup:
if ip_bans_lookup:
# Verify if IP is not banned # Verify if IP is not banned
ip_address_ = ip_address(request.remote) # type: ignore[arg-type] ip_address_ = ip_address(request.remote) # type: ignore[arg-type]
if ip_address_ in ip_bans_lookup: if ip_address_ in ip_bans_lookup:
@ -154,7 +161,7 @@ async def process_wrong_login(request: Request) -> None:
request.app[KEY_FAILED_LOGIN_ATTEMPTS][remote_addr] request.app[KEY_FAILED_LOGIN_ATTEMPTS][remote_addr]
>= request.app[KEY_LOGIN_THRESHOLD] >= request.app[KEY_LOGIN_THRESHOLD]
): ):
ban_manager: IpBanManager = request.app[KEY_BAN_MANAGER] ban_manager = request.app[KEY_BAN_MANAGER]
_LOGGER.warning("Banned IP %s for too many login attempts", remote_addr) _LOGGER.warning("Banned IP %s for too many login attempts", remote_addr)
await ban_manager.async_add_ban(remote_addr) await ban_manager.async_add_ban(remote_addr)
@ -180,9 +187,7 @@ def process_success_login(request: Request) -> None:
return return
remote_addr = ip_address(request.remote) # type: ignore[arg-type] remote_addr = ip_address(request.remote) # type: ignore[arg-type]
login_attempt_history: defaultdict[IPv4Address | IPv6Address, int] = app[ login_attempt_history = app[KEY_FAILED_LOGIN_ATTEMPTS]
KEY_FAILED_LOGIN_ATTEMPTS
]
if remote_addr in login_attempt_history and login_attempt_history[remote_addr] > 0: if remote_addr in login_attempt_history and login_attempt_history[remote_addr] > 0:
_LOGGER.debug( _LOGGER.debug(
"Login success, reset failed login attempts counter from %s", remote_addr "Login success, reset failed login attempts counter from %s", remote_addr

View File

@ -15,7 +15,6 @@ from homeassistant.components.http.ban import (
IP_BANS_FILE, IP_BANS_FILE,
KEY_BAN_MANAGER, KEY_BAN_MANAGER,
KEY_FAILED_LOGIN_ATTEMPTS, KEY_FAILED_LOGIN_ATTEMPTS,
IpBanManager,
process_success_login, process_success_login,
setup_bans, setup_bans,
) )
@ -215,7 +214,7 @@ async def test_access_from_supervisor_ip(
): ):
client = await aiohttp_client(app) client = await aiohttp_client(app)
manager: IpBanManager = app[KEY_BAN_MANAGER] manager = app[KEY_BAN_MANAGER]
with patch( with patch(
"homeassistant.components.hassio.HassIO.get_resolution_info", "homeassistant.components.hassio.HassIO.get_resolution_info",
@ -288,7 +287,7 @@ async def test_ip_bans_file_creation(
): ):
client = await aiohttp_client(app) client = await aiohttp_client(app)
manager: IpBanManager = app[KEY_BAN_MANAGER] manager = app[KEY_BAN_MANAGER]
m_open = mock_open() m_open = mock_open()
with patch("homeassistant.components.http.ban.open", m_open, create=True): with patch("homeassistant.components.http.ban.open", m_open, create=True):
@ -408,7 +407,7 @@ async def test_single_ban_file_entry(
setup_bans(hass, app, 2) setup_bans(hass, app, 2)
mock_real_ip(app)("200.201.202.204") mock_real_ip(app)("200.201.202.204")
manager: IpBanManager = app[KEY_BAN_MANAGER] manager = app[KEY_BAN_MANAGER]
m_open = mock_open() m_open = mock_open()
with patch("homeassistant.components.http.ban.open", m_open, create=True): with patch("homeassistant.components.http.ban.open", m_open, create=True):