Add missing type hints in http component (#50411)

This commit is contained in:
Ruslan Sayfutdinov 2021-05-10 22:30:47 +01:00 committed by GitHub
parent 85f758380a
commit ce15f28642
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 245 additions and 160 deletions

View File

@ -6,16 +6,18 @@ from ipaddress import ip_network
import logging import logging
import os import os
import ssl import ssl
from typing import Any, Optional, cast from typing import Any, Final, Optional, TypedDict, cast
from aiohttp import web from aiohttp import web
from aiohttp.web_exceptions import HTTPMovedPermanently from aiohttp.typedefs import StrOrURL
from aiohttp.web_exceptions import HTTPMovedPermanently, HTTPRedirection
import voluptuous as vol import voluptuous as vol
from homeassistant.const import EVENT_HOMEASSISTANT_STOP, SERVER_PORT from homeassistant.const import EVENT_HOMEASSISTANT_STOP, SERVER_PORT
from homeassistant.core import Event, HomeAssistant from homeassistant.core import Event, HomeAssistant
from homeassistant.helpers import storage from homeassistant.helpers import storage
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
from homeassistant.setup import async_start_setup, async_when_setup_or_start from homeassistant.setup import async_start_setup, async_when_setup_or_start
import homeassistant.util as hass_util import homeassistant.util as hass_util
@ -29,44 +31,42 @@ from .forwarded import async_setup_forwarded
from .request_context import setup_request_context from .request_context import setup_request_context
from .security_filter import setup_security_filter from .security_filter import setup_security_filter
from .static import CACHE_HEADERS, CachingStaticResource from .static import CACHE_HEADERS, CachingStaticResource
from .view import HomeAssistantView # noqa: F401 from .view import HomeAssistantView
from .web_runner import HomeAssistantTCPSite from .web_runner import HomeAssistantTCPSite
# mypy: allow-untyped-defs, no-check-untyped-defs DOMAIN: Final = "http"
DOMAIN = "http" CONF_SERVER_HOST: Final = "server_host"
CONF_SERVER_PORT: Final = "server_port"
CONF_BASE_URL: Final = "base_url"
CONF_SSL_CERTIFICATE: Final = "ssl_certificate"
CONF_SSL_PEER_CERTIFICATE: Final = "ssl_peer_certificate"
CONF_SSL_KEY: Final = "ssl_key"
CONF_CORS_ORIGINS: Final = "cors_allowed_origins"
CONF_USE_X_FORWARDED_FOR: Final = "use_x_forwarded_for"
CONF_TRUSTED_PROXIES: Final = "trusted_proxies"
CONF_LOGIN_ATTEMPTS_THRESHOLD: Final = "login_attempts_threshold"
CONF_IP_BAN_ENABLED: Final = "ip_ban_enabled"
CONF_SSL_PROFILE: Final = "ssl_profile"
CONF_SERVER_HOST = "server_host" SSL_MODERN: Final = "modern"
CONF_SERVER_PORT = "server_port" SSL_INTERMEDIATE: Final = "intermediate"
CONF_BASE_URL = "base_url"
CONF_SSL_CERTIFICATE = "ssl_certificate"
CONF_SSL_PEER_CERTIFICATE = "ssl_peer_certificate"
CONF_SSL_KEY = "ssl_key"
CONF_CORS_ORIGINS = "cors_allowed_origins"
CONF_USE_X_FORWARDED_FOR = "use_x_forwarded_for"
CONF_TRUSTED_PROXIES = "trusted_proxies"
CONF_LOGIN_ATTEMPTS_THRESHOLD = "login_attempts_threshold"
CONF_IP_BAN_ENABLED = "ip_ban_enabled"
CONF_SSL_PROFILE = "ssl_profile"
SSL_MODERN = "modern" _LOGGER: Final = logging.getLogger(__name__)
SSL_INTERMEDIATE = "intermediate"
_LOGGER = logging.getLogger(__name__) DEFAULT_DEVELOPMENT: Final = "0"
DEFAULT_DEVELOPMENT = "0"
# Cast to be able to load custom cards. # Cast to be able to load custom cards.
# My to be able to check url and version info. # My to be able to check url and version info.
DEFAULT_CORS = ["https://cast.home-assistant.io"] DEFAULT_CORS: Final[list[str]] = ["https://cast.home-assistant.io"]
NO_LOGIN_ATTEMPT_THRESHOLD = -1 NO_LOGIN_ATTEMPT_THRESHOLD: Final = -1
MAX_CLIENT_SIZE: int = 1024 ** 2 * 16 MAX_CLIENT_SIZE: Final = 1024 ** 2 * 16
STORAGE_KEY = DOMAIN STORAGE_KEY: Final = DOMAIN
STORAGE_VERSION = 1 STORAGE_VERSION: Final = 1
SAVE_DELAY = 180 SAVE_DELAY: Final = 180
HTTP_SCHEMA = vol.All( HTTP_SCHEMA: Final = vol.All(
cv.deprecated(CONF_BASE_URL), cv.deprecated(CONF_BASE_URL),
vol.Schema( vol.Schema(
{ {
@ -96,7 +96,24 @@ HTTP_SCHEMA = vol.All(
), ),
) )
CONFIG_SCHEMA = vol.Schema({DOMAIN: HTTP_SCHEMA}, extra=vol.ALLOW_EXTRA) CONFIG_SCHEMA: Final = vol.Schema({DOMAIN: HTTP_SCHEMA}, extra=vol.ALLOW_EXTRA)
class ConfData(TypedDict, total=False):
"""Typed dict for config data."""
server_host: list[str]
server_port: int
base_url: str
ssl_certificate: str
ssl_peer_certificate: str
ssl_key: str
cors_allowed_origins: list[str]
use_x_forwarded_for: bool
trusted_proxies: list[str]
login_attempts_threshold: int
ip_ban_enabled: bool
ssl_profile: str
@bind_hass @bind_hass
@ -113,8 +130,8 @@ class ApiConfig:
self, self,
local_ip: str, local_ip: str,
host: str, host: str,
port: int | None = SERVER_PORT, port: int,
use_ssl: bool = False, use_ssl: bool,
) -> None: ) -> None:
"""Initialize a new API config object.""" """Initialize a new API config object."""
self.local_ip = local_ip self.local_ip = local_ip
@ -123,12 +140,12 @@ class ApiConfig:
self.use_ssl = use_ssl self.use_ssl = use_ssl
async def async_setup(hass, config): async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up the HTTP API and debug interface.""" """Set up the HTTP API and debug interface."""
conf = config.get(DOMAIN) conf: ConfData | None = config.get(DOMAIN)
if conf is None: if conf is None:
conf = HTTP_SCHEMA({}) conf = cast(ConfData, HTTP_SCHEMA({}))
server_host = conf.get(CONF_SERVER_HOST) server_host = conf.get(CONF_SERVER_HOST)
server_port = conf[CONF_SERVER_PORT] server_port = conf[CONF_SERVER_PORT]
@ -137,7 +154,7 @@ async def async_setup(hass, config):
ssl_key = conf.get(CONF_SSL_KEY) ssl_key = conf.get(CONF_SSL_KEY)
cors_origins = conf[CONF_CORS_ORIGINS] cors_origins = conf[CONF_CORS_ORIGINS]
use_x_forwarded_for = conf.get(CONF_USE_X_FORWARDED_FOR, False) use_x_forwarded_for = conf.get(CONF_USE_X_FORWARDED_FOR, False)
trusted_proxies = conf.get(CONF_TRUSTED_PROXIES, []) trusted_proxies = conf.get(CONF_TRUSTED_PROXIES) or []
is_ban_enabled = conf[CONF_IP_BAN_ENABLED] is_ban_enabled = conf[CONF_IP_BAN_ENABLED]
login_threshold = conf[CONF_LOGIN_ATTEMPTS_THRESHOLD] login_threshold = conf[CONF_LOGIN_ATTEMPTS_THRESHOLD]
ssl_profile = conf[CONF_SSL_PROFILE] ssl_profile = conf[CONF_SSL_PROFILE]
@ -165,6 +182,8 @@ async def async_setup(hass, config):
"""Start the server.""" """Start the server."""
with async_start_setup(hass, ["http"]): with async_start_setup(hass, ["http"]):
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, stop_server) hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, stop_server)
# We already checked it's not None.
assert conf is not None
await start_http_server_and_save_config(hass, dict(conf), server) await start_http_server_and_save_config(hass, dict(conf), server)
async_when_setup_or_start(hass, "frontend", start_server) async_when_setup_or_start(hass, "frontend", start_server)
@ -190,19 +209,19 @@ class HomeAssistantHTTP:
def __init__( def __init__(
self, self,
hass, hass: HomeAssistant,
ssl_certificate, ssl_certificate: str | None,
ssl_peer_certificate, ssl_peer_certificate: str | None,
ssl_key, ssl_key: str | None,
server_host, server_host: list[str] | None,
server_port, server_port: int,
cors_origins, cors_origins: list[str],
use_x_forwarded_for, use_x_forwarded_for: bool,
trusted_proxies, trusted_proxies: list[str],
login_threshold, login_threshold: int,
is_ban_enabled, is_ban_enabled: bool,
ssl_profile, ssl_profile: str,
): ) -> None:
"""Initialize the HTTP Home Assistant server.""" """Initialize the HTTP Home Assistant server."""
app = self.app = web.Application( app = self.app = web.Application(
middlewares=[], client_max_size=MAX_CLIENT_SIZE middlewares=[], client_max_size=MAX_CLIENT_SIZE
@ -237,10 +256,10 @@ class HomeAssistantHTTP:
self.is_ban_enabled = is_ban_enabled self.is_ban_enabled = is_ban_enabled
self.ssl_profile = ssl_profile self.ssl_profile = ssl_profile
self._handler = None self._handler = None
self.runner = None self.runner: web.AppRunner | None = None
self.site = None self.site: HomeAssistantTCPSite | None = None
def register_view(self, view): def register_view(self, view: HomeAssistantView) -> None:
"""Register a view with the WSGI server. """Register a view with the WSGI server.
The view argument must be a class that inherits from HomeAssistantView. The view argument must be a class that inherits from HomeAssistantView.
@ -261,7 +280,13 @@ class HomeAssistantHTTP:
view.register(self.app, self.app.router) view.register(self.app, self.app.router)
def register_redirect(self, url, redirect_to, *, redirect_exc=HTTPMovedPermanently): def register_redirect(
self,
url: str,
redirect_to: StrOrURL,
*,
redirect_exc: type[HTTPRedirection] = HTTPMovedPermanently,
) -> None:
"""Register a redirect with the server. """Register a redirect with the server.
If given this must be either a string or callable. In case of a If given this must be either a string or callable. In case of a
@ -271,38 +296,39 @@ class HomeAssistantHTTP:
rule syntax. rule syntax.
""" """
async def redirect(request): async def redirect(request: web.Request) -> web.StreamResponse:
"""Redirect to location.""" """Redirect to location."""
raise redirect_exc(redirect_to) # Should be instance of aiohttp.web_exceptions._HTTPMove.
raise redirect_exc(redirect_to) # type: ignore[arg-type,misc]
self.app.router.add_route("GET", url, redirect) self.app.router.add_route("GET", url, redirect)
def register_static_path(self, url_path, path, cache_headers=True): def register_static_path(
self, url_path: str, path: str, cache_headers: bool = True
) -> web.FileResponse | None:
"""Register a folder or file to serve as a static path.""" """Register a folder or file to serve as a static path."""
if os.path.isdir(path): if os.path.isdir(path):
if cache_headers: if cache_headers:
resource = CachingStaticResource resource: type[
CachingStaticResource | web.StaticResource
] = CachingStaticResource
else: else:
resource = web.StaticResource resource = web.StaticResource
self.app.router.register_resource(resource(url_path, path)) self.app.router.register_resource(resource(url_path, path))
return return None
if cache_headers: async def serve_file(request: web.Request) -> web.FileResponse:
"""Serve file from disk."""
async def serve_file(request): if cache_headers:
"""Serve file from disk."""
return web.FileResponse(path, headers=CACHE_HEADERS) return web.FileResponse(path, headers=CACHE_HEADERS)
return web.FileResponse(path)
else:
async def serve_file(request):
"""Serve file from disk."""
return web.FileResponse(path)
self.app.router.add_route("GET", url_path, serve_file) self.app.router.add_route("GET", url_path, serve_file)
return None
async def start(self): async def start(self) -> None:
"""Start the aiohttp server.""" """Start the aiohttp server."""
context: ssl.SSLContext | None
if self.ssl_certificate: if self.ssl_certificate:
try: try:
if self.ssl_profile == SSL_INTERMEDIATE: if self.ssl_profile == SSL_INTERMEDIATE:
@ -334,7 +360,7 @@ class HomeAssistantHTTP:
# This will now raise a RunTimeError. # This will now raise a RunTimeError.
# To work around this we now prevent the router from getting frozen # To work around this we now prevent the router from getting frozen
# pylint: disable=protected-access # pylint: disable=protected-access
self.app._router.freeze = lambda: None self.app._router.freeze = lambda: None # type: ignore[assignment]
self.runner = web.AppRunner(self.app) self.runner = web.AppRunner(self.app)
await self.runner.setup() await self.runner.setup()
@ -351,17 +377,19 @@ class HomeAssistantHTTP:
_LOGGER.info("Now listening on port %d", self.server_port) _LOGGER.info("Now listening on port %d", self.server_port)
async def stop(self): async def stop(self) -> None:
"""Stop the aiohttp server.""" """Stop the aiohttp server."""
await self.site.stop() if self.site is not None:
await self.runner.cleanup() await self.site.stop()
if self.runner is not None:
await self.runner.cleanup()
async def start_http_server_and_save_config( async def start_http_server_and_save_config(
hass: HomeAssistant, conf: dict, server: HomeAssistantHTTP hass: HomeAssistant, conf: dict, server: HomeAssistantHTTP
) -> None: ) -> None:
"""Startup the http server and save the config.""" """Startup the http server and save the config."""
await server.start() # type: ignore await server.start()
# If we are set up successful, we store the HTTP settings for safe mode. # If we are set up successful, we store the HTTP settings for safe mode.
store = storage.Store(hass, STORAGE_VERSION, STORAGE_KEY) store = storage.Store(hass, STORAGE_VERSION, STORAGE_KEY)

View File

@ -1,28 +1,33 @@
"""Authentication for HTTP component.""" """Authentication for HTTP component."""
from __future__ import annotations
from collections.abc import Awaitable, Callable
from datetime import timedelta
import logging import logging
import secrets import secrets
from typing import Final
from urllib.parse import unquote from urllib.parse import unquote
from aiohttp import hdrs from aiohttp import hdrs
from aiohttp.web import middleware from aiohttp.web import Application, Request, StreamResponse, middleware
import jwt import jwt
from homeassistant.core import callback from homeassistant.core import HomeAssistant, callback
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
from .const import KEY_AUTHENTICATED, KEY_HASS_REFRESH_TOKEN_ID, KEY_HASS_USER from .const import KEY_AUTHENTICATED, KEY_HASS_REFRESH_TOKEN_ID, KEY_HASS_USER
# mypy: allow-untyped-defs, no-check-untyped-defs
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
DATA_API_PASSWORD = "api_password" DATA_API_PASSWORD: Final = "api_password"
DATA_SIGN_SECRET = "http.auth.sign_secret" DATA_SIGN_SECRET: Final = "http.auth.sign_secret"
SIGN_QUERY_PARAM = "authSig" SIGN_QUERY_PARAM: Final = "authSig"
@callback @callback
def async_sign_path(hass, refresh_token_id, path, expiration): def async_sign_path(
hass: HomeAssistant, refresh_token_id: str, path: str, expiration: timedelta
) -> str:
"""Sign a path for temporary access without auth header.""" """Sign a path for temporary access without auth header."""
secret = hass.data.get(DATA_SIGN_SECRET) secret = hass.data.get(DATA_SIGN_SECRET)
@ -44,17 +49,19 @@ def async_sign_path(hass, refresh_token_id, path, expiration):
@callback @callback
def setup_auth(hass, app): def setup_auth(hass: HomeAssistant, app: Application) -> None:
"""Create auth middleware for the app.""" """Create auth middleware for the app."""
async def async_validate_auth_header(request): async def async_validate_auth_header(request: Request) -> bool:
""" """
Test authorization header against access token. Test authorization header against access token.
Basic auth_type is legacy code, should be removed with api_password. Basic auth_type is legacy code, should be removed with api_password.
""" """
try: try:
auth_type, auth_val = request.headers.get(hdrs.AUTHORIZATION).split(" ", 1) auth_type, auth_val = request.headers.get(hdrs.AUTHORIZATION, "").split(
" ", 1
)
except ValueError: except ValueError:
# If no space in authorization header # If no space in authorization header
return False return False
@ -71,7 +78,7 @@ def setup_auth(hass, app):
request[KEY_HASS_REFRESH_TOKEN_ID] = refresh_token.id request[KEY_HASS_REFRESH_TOKEN_ID] = refresh_token.id
return True return True
async def async_validate_signed_request(request): async def async_validate_signed_request(request: Request) -> bool:
"""Validate a signed request.""" """Validate a signed request."""
secret = hass.data.get(DATA_SIGN_SECRET) secret = hass.data.get(DATA_SIGN_SECRET)
@ -103,7 +110,9 @@ def setup_auth(hass, app):
return True return True
@middleware @middleware
async def auth_middleware(request, handler): async def auth_middleware(
request: Request, handler: Callable[[Request], Awaitable[StreamResponse]]
) -> StreamResponse:
"""Authenticate as middleware.""" """Authenticate as middleware."""
authenticated = False authenticated = False

View File

@ -2,13 +2,15 @@
from __future__ import annotations from __future__ import annotations
from collections import defaultdict from collections import defaultdict
from collections.abc import Awaitable, Callable
from contextlib import suppress from contextlib import suppress
from datetime import datetime from datetime import datetime
from ipaddress import ip_address from ipaddress import ip_address
import logging import logging
from socket import gethostbyaddr, herror from socket import gethostbyaddr, herror
from typing import Any, Final
from aiohttp.web import middleware from aiohttp.web import Application, Request, StreamResponse, middleware
from aiohttp.web_exceptions import HTTPForbidden, HTTPUnauthorized from aiohttp.web_exceptions import HTTPForbidden, HTTPUnauthorized
import voluptuous as vol import voluptuous as vol
@ -19,33 +21,33 @@ from homeassistant.exceptions import HomeAssistantError
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.util import dt as dt_util, yaml from homeassistant.util import dt as dt_util, yaml
# mypy: allow-untyped-defs, no-check-untyped-defs from .view import HomeAssistantView
_LOGGER = logging.getLogger(__name__) _LOGGER: Final = logging.getLogger(__name__)
KEY_BANNED_IPS = "ha_banned_ips" KEY_BANNED_IPS: Final = "ha_banned_ips"
KEY_FAILED_LOGIN_ATTEMPTS = "ha_failed_login_attempts" KEY_FAILED_LOGIN_ATTEMPTS: Final = "ha_failed_login_attempts"
KEY_LOGIN_THRESHOLD = "ha_login_threshold" KEY_LOGIN_THRESHOLD: Final = "ha_login_threshold"
NOTIFICATION_ID_BAN = "ip-ban" NOTIFICATION_ID_BAN: Final = "ip-ban"
NOTIFICATION_ID_LOGIN = "http-login" NOTIFICATION_ID_LOGIN: Final = "http-login"
IP_BANS_FILE = "ip_bans.yaml" IP_BANS_FILE: Final = "ip_bans.yaml"
ATTR_BANNED_AT = "banned_at" ATTR_BANNED_AT: Final = "banned_at"
SCHEMA_IP_BAN_ENTRY = vol.Schema( SCHEMA_IP_BAN_ENTRY: Final = vol.Schema(
{vol.Optional("banned_at"): vol.Any(None, cv.datetime)} {vol.Optional("banned_at"): vol.Any(None, cv.datetime)}
) )
@callback @callback
def setup_bans(hass, app, login_threshold): def setup_bans(hass: HomeAssistant, app: Application, login_threshold: int) -> None:
"""Create IP Ban middleware for the app.""" """Create IP Ban middleware for the app."""
app.middlewares.append(ban_middleware) app.middlewares.append(ban_middleware)
app[KEY_FAILED_LOGIN_ATTEMPTS] = defaultdict(int) app[KEY_FAILED_LOGIN_ATTEMPTS] = defaultdict(int)
app[KEY_LOGIN_THRESHOLD] = login_threshold app[KEY_LOGIN_THRESHOLD] = login_threshold
async def ban_startup(app): async def ban_startup(app: Application) -> None:
"""Initialize bans when app starts up.""" """Initialize bans when app starts up."""
app[KEY_BANNED_IPS] = await async_load_ip_bans_config( app[KEY_BANNED_IPS] = await async_load_ip_bans_config(
hass, hass.config.path(IP_BANS_FILE) hass, hass.config.path(IP_BANS_FILE)
@ -55,7 +57,9 @@ def setup_bans(hass, app, login_threshold):
@middleware @middleware
async def ban_middleware(request, handler): async def ban_middleware(
request: Request, handler: Callable[[Request], Awaitable[StreamResponse]]
) -> StreamResponse:
"""IP Ban middleware.""" """IP Ban middleware."""
if KEY_BANNED_IPS not in request.app: if KEY_BANNED_IPS not in request.app:
_LOGGER.error("IP Ban middleware loaded but banned IPs not loaded") _LOGGER.error("IP Ban middleware loaded but banned IPs not loaded")
@ -77,10 +81,14 @@ async def ban_middleware(request, handler):
raise raise
def log_invalid_auth(func): def log_invalid_auth(
func: Callable[..., Awaitable[StreamResponse]]
) -> Callable[..., Awaitable[StreamResponse]]:
"""Decorate function to handle invalid auth or failed login attempts.""" """Decorate function to handle invalid auth or failed login attempts."""
async def handle_req(view, request, *args, **kwargs): async def handle_req(
view: HomeAssistantView, request: Request, *args: Any, **kwargs: Any
) -> StreamResponse:
"""Try to log failed login attempts if response status >= 400.""" """Try to log failed login attempts if response status >= 400."""
resp = await func(view, request, *args, **kwargs) resp = await func(view, request, *args, **kwargs)
if resp.status >= HTTP_BAD_REQUEST: if resp.status >= HTTP_BAD_REQUEST:
@ -90,7 +98,7 @@ def log_invalid_auth(func):
return handle_req return handle_req
async def process_wrong_login(request): async def process_wrong_login(request: Request) -> None:
"""Process a wrong login attempt. """Process a wrong login attempt.
Increase failed login attempts counter for remote IP address. Increase failed login attempts counter for remote IP address.
@ -152,7 +160,7 @@ async def process_wrong_login(request):
) )
async def process_success_login(request): async def process_success_login(request: Request) -> None:
"""Process a success login attempt. """Process a success login attempt.
Reset failed login attempts counter for remote IP address. Reset failed login attempts counter for remote IP address.

View File

@ -1,5 +1,7 @@
"""HTTP specific constants.""" """HTTP specific constants."""
KEY_AUTHENTICATED = "ha_authenticated" from typing import Final
KEY_HASS = "hass"
KEY_HASS_USER = "hass_user" KEY_AUTHENTICATED: Final = "ha_authenticated"
KEY_HASS_REFRESH_TOKEN_ID = "hass_refresh_token_id" KEY_HASS: Final = "hass"
KEY_HASS_USER: Final = "hass_user"
KEY_HASS_REFRESH_TOKEN_ID: Final = "hass_refresh_token_id"

View File

@ -1,24 +1,33 @@
"""Provide CORS support for the HTTP component.""" """Provide CORS support for the HTTP component."""
from __future__ import annotations
from typing import Final
from aiohttp.hdrs import ACCEPT, AUTHORIZATION, CONTENT_TYPE, ORIGIN from aiohttp.hdrs import ACCEPT, AUTHORIZATION, CONTENT_TYPE, ORIGIN
from aiohttp.web_urldispatcher import Resource, ResourceRoute, StaticResource from aiohttp.web import Application
from aiohttp.web_urldispatcher import (
AbstractResource,
AbstractRoute,
Resource,
ResourceRoute,
StaticResource,
)
from homeassistant.const import HTTP_HEADER_X_REQUESTED_WITH from homeassistant.const import HTTP_HEADER_X_REQUESTED_WITH
from homeassistant.core import callback from homeassistant.core import callback
# mypy: allow-untyped-defs, no-check-untyped-defs ALLOWED_CORS_HEADERS: Final[list[str]] = [
ALLOWED_CORS_HEADERS = [
ORIGIN, ORIGIN,
ACCEPT, ACCEPT,
HTTP_HEADER_X_REQUESTED_WITH, HTTP_HEADER_X_REQUESTED_WITH,
CONTENT_TYPE, CONTENT_TYPE,
AUTHORIZATION, AUTHORIZATION,
] ]
VALID_CORS_TYPES = (Resource, ResourceRoute, StaticResource) VALID_CORS_TYPES: Final = (Resource, ResourceRoute, StaticResource)
@callback @callback
def setup_cors(app, origins): def setup_cors(app: Application, origins: list[str]) -> None:
"""Set up CORS.""" """Set up CORS."""
# This import should remain here. That way the HTTP integration can always # This import should remain here. That way the HTTP integration can always
# be imported by other integrations without it's requirements being installed. # be imported by other integrations without it's requirements being installed.
@ -37,9 +46,12 @@ def setup_cors(app, origins):
cors_added = set() cors_added = set()
def _allow_cors(route, config=None): def _allow_cors(
route: AbstractRoute | AbstractResource,
config: dict[str, aiohttp_cors.ResourceOptions] | None = None,
) -> None:
"""Allow CORS on a route.""" """Allow CORS on a route."""
if hasattr(route, "resource"): if isinstance(route, AbstractRoute):
path = route.resource path = route.resource
else: else:
path = route path = route
@ -47,16 +59,16 @@ def setup_cors(app, origins):
if not isinstance(path, VALID_CORS_TYPES): if not isinstance(path, VALID_CORS_TYPES):
return return
path = path.canonical path_str = path.canonical
if path.startswith("/api/hassio_ingress/"): if path_str.startswith("/api/hassio_ingress/"):
return return
if path in cors_added: if path_str in cors_added:
return return
cors.add(route, config) cors.add(route, config)
cors_added.add(path) cors_added.add(path_str)
app["allow_cors"] = lambda route: _allow_cors( app["allow_cors"] = lambda route: _allow_cors(
route, route,
@ -70,7 +82,7 @@ def setup_cors(app, origins):
if not origins: if not origins:
return return
async def cors_startup(app): async def cors_startup(app: Application) -> None:
"""Initialize CORS when app starts up.""" """Initialize CORS when app starts up."""
for resource in list(app.router.resources()): for resource in list(app.router.resources()):
_allow_cors(resource) _allow_cors(resource)

View File

@ -1,19 +1,20 @@
"""Middleware to handle forwarded data by a reverse proxy.""" """Middleware to handle forwarded data by a reverse proxy."""
from __future__ import annotations
from collections.abc import Awaitable, Callable
from ipaddress import ip_address from ipaddress import ip_address
import logging import logging
from aiohttp.hdrs import X_FORWARDED_FOR, X_FORWARDED_HOST, X_FORWARDED_PROTO from aiohttp.hdrs import X_FORWARDED_FOR, X_FORWARDED_HOST, X_FORWARDED_PROTO
from aiohttp.web import HTTPBadRequest, middleware from aiohttp.web import Application, HTTPBadRequest, Request, StreamResponse, middleware
from homeassistant.core import callback from homeassistant.core import callback
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
# mypy: allow-untyped-defs
@callback @callback
def async_setup_forwarded(app, trusted_proxies): def async_setup_forwarded(app: Application, trusted_proxies: list[str]) -> None:
"""Create forwarded middleware for the app. """Create forwarded middleware for the app.
Process IP addresses, proto and host information in the forwarded for headers. Process IP addresses, proto and host information in the forwarded for headers.
@ -60,17 +61,20 @@ def async_setup_forwarded(app, trusted_proxies):
""" """
@middleware @middleware
async def forwarded_middleware(request, handler): async def forwarded_middleware(
request: Request, handler: Callable[[Request], Awaitable[StreamResponse]]
) -> StreamResponse:
"""Process forwarded data by a reverse proxy.""" """Process forwarded data by a reverse proxy."""
overrides = {} overrides: dict[str, str] = {}
# Handle X-Forwarded-For # Handle X-Forwarded-For
forwarded_for_headers = request.headers.getall(X_FORWARDED_FOR, []) forwarded_for_headers: list[str] = request.headers.getall(X_FORWARDED_FOR, [])
if not forwarded_for_headers: if not forwarded_for_headers:
# No forwarding headers, continue as normal # No forwarding headers, continue as normal
return await handler(request) return await handler(request)
# Ensure the IP of the connected peer is trusted # Ensure the IP of the connected peer is trusted
assert request.transport is not None
connected_ip = ip_address(request.transport.get_extra_info("peername")[0]) connected_ip = ip_address(request.transport.get_extra_info("peername")[0])
if not any(connected_ip in trusted_proxy for trusted_proxy in trusted_proxies): if not any(connected_ip in trusted_proxy for trusted_proxy in trusted_proxies):
_LOGGER.warning( _LOGGER.warning(
@ -111,7 +115,9 @@ def async_setup_forwarded(app, trusted_proxies):
overrides["remote"] = str(forwarded_for[-1]) overrides["remote"] = str(forwarded_for[-1])
# Handle X-Forwarded-Proto # Handle X-Forwarded-Proto
forwarded_proto_headers = request.headers.getall(X_FORWARDED_PROTO, []) forwarded_proto_headers: list[str] = request.headers.getall(
X_FORWARDED_PROTO, []
)
if forwarded_proto_headers: if forwarded_proto_headers:
if len(forwarded_proto_headers) > 1: if len(forwarded_proto_headers) > 1:
_LOGGER.error( _LOGGER.error(
@ -151,7 +157,7 @@ def async_setup_forwarded(app, trusted_proxies):
overrides["scheme"] = forwarded_proto[forwarded_for_index] overrides["scheme"] = forwarded_proto[forwarded_for_index]
# Handle X-Forwarded-Host # Handle X-Forwarded-Host
forwarded_host_headers = request.headers.getall(X_FORWARDED_HOST, []) forwarded_host_headers: list[str] = request.headers.getall(X_FORWARDED_HOST, [])
if forwarded_host_headers: if forwarded_host_headers:
# Multiple X-Forwarded-Host headers # Multiple X-Forwarded-Host headers
if len(forwarded_host_headers) > 1: if len(forwarded_host_headers) > 1:
@ -168,7 +174,7 @@ def async_setup_forwarded(app, trusted_proxies):
overrides["host"] = forwarded_host overrides["host"] = forwarded_host
# Done, create a new request based on gathered data. # Done, create a new request based on gathered data.
request = request.clone(**overrides) request = request.clone(**overrides) # type: ignore[arg-type]
return await handler(request) return await handler(request)
app.middlewares.append(forwarded_middleware) app.middlewares.append(forwarded_middleware)

View File

@ -1,18 +1,24 @@
"""Middleware to set the request context.""" """Middleware to set the request context."""
from __future__ import annotations
from aiohttp.web import middleware from collections.abc import Awaitable, Callable
from contextvars import ContextVar
from aiohttp.web import Application, Request, StreamResponse, middleware
from homeassistant.core import callback from homeassistant.core import callback
# mypy: allow-untyped-defs
@callback @callback
def setup_request_context(app, context): def setup_request_context(
app: Application, context: ContextVar[Request | None]
) -> None:
"""Create request context middleware for the app.""" """Create request context middleware for the app."""
@middleware @middleware
async def request_context_middleware(request, handler): async def request_context_middleware(
request: Request, handler: Callable[[Request], Awaitable[StreamResponse]]
) -> StreamResponse:
"""Request context middleware.""" """Request context middleware."""
context.set(request) context.set(request)
return await handler(request) return await handler(request)

View File

@ -1,17 +1,19 @@
"""Middleware to add some basic security filtering to requests.""" """Middleware to add some basic security filtering to requests."""
from __future__ import annotations
from collections.abc import Awaitable, Callable
import logging import logging
import re import re
from typing import Final
from aiohttp.web import HTTPBadRequest, middleware from aiohttp.web import Application, HTTPBadRequest, Request, StreamResponse, middleware
from homeassistant.core import callback from homeassistant.core import callback
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
# mypy: allow-untyped-defs
# fmt: off # fmt: off
FILTERS = re.compile( FILTERS: Final = re.compile(
r"(?:" r"(?:"
# Common exploits # Common exploits
@ -34,12 +36,14 @@ FILTERS = re.compile(
@callback @callback
def setup_security_filter(app): def setup_security_filter(app: Application) -> None:
"""Create security filter middleware for the app.""" """Create security filter middleware for the app."""
@middleware @middleware
async def security_filter_middleware(request, handler): async def security_filter_middleware(
"""Process request and block commonly known exploit attempts.""" request: Request, handler: Callable[[Request], Awaitable[StreamResponse]]
) -> StreamResponse:
"""Process request and tblock commonly known exploit attempts."""
if FILTERS.search(request.path): if FILTERS.search(request.path):
_LOGGER.warning( _LOGGER.warning(
"Filtered a potential harmful request to: %s", request.raw_path "Filtered a potential harmful request to: %s", request.raw_path

View File

@ -1,21 +1,25 @@
"""Static file handling for HTTP component.""" """Static file handling for HTTP component."""
from __future__ import annotations
from collections.abc import Mapping
from pathlib import Path from pathlib import Path
from typing import Final
from aiohttp import hdrs from aiohttp import hdrs
from aiohttp.web import FileResponse from aiohttp.web import FileResponse, Request, StreamResponse
from aiohttp.web_exceptions import HTTPForbidden, HTTPNotFound from aiohttp.web_exceptions import HTTPForbidden, HTTPNotFound
from aiohttp.web_urldispatcher import StaticResource from aiohttp.web_urldispatcher import StaticResource
# mypy: allow-untyped-defs CACHE_TIME: Final = 31 * 86400 # = 1 month
CACHE_HEADERS: Final[Mapping[str, str]] = {
CACHE_TIME = 31 * 86400 # = 1 month hdrs.CACHE_CONTROL: f"public, max-age={CACHE_TIME}"
CACHE_HEADERS = {hdrs.CACHE_CONTROL: f"public, max-age={CACHE_TIME}"} }
class CachingStaticResource(StaticResource): class CachingStaticResource(StaticResource):
"""Static Resource handler that will add cache headers.""" """Static Resource handler that will add cache headers."""
async def _handle(self, request): async def _handle(self, request: Request) -> StreamResponse:
rel_url = request.match_info["filename"] rel_url = request.match_info["filename"]
try: try:
filename = Path(rel_url) filename = Path(rel_url)
@ -42,7 +46,6 @@ class CachingStaticResource(StaticResource):
return FileResponse( return FileResponse(
filepath, filepath,
chunk_size=self._chunk_size, chunk_size=self._chunk_size,
# type ignore: https://github.com/aio-libs/aiohttp/pull/3976 headers=CACHE_HEADERS,
headers=CACHE_HEADERS, # type: ignore
) )
raise HTTPNotFound raise HTTPNotFound

View File

@ -2,9 +2,10 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections.abc import Awaitable, Callable
import json import json
import logging import logging
from typing import Any, Callable from typing import Any
from aiohttp import web from aiohttp import web
from aiohttp.typedefs import LooseHeaders from aiohttp.typedefs import LooseHeaders
@ -13,6 +14,7 @@ from aiohttp.web_exceptions import (
HTTPInternalServerError, HTTPInternalServerError,
HTTPUnauthorized, HTTPUnauthorized,
) )
from aiohttp.web_urldispatcher import AbstractRoute
import voluptuous as vol import voluptuous as vol
from homeassistant import exceptions from homeassistant import exceptions
@ -81,7 +83,7 @@ class HomeAssistantView:
"""Register the view with a router.""" """Register the view with a router."""
assert self.url is not None, "No url set for view" assert self.url is not None, "No url set for view"
urls = [self.url] + self.extra_urls urls = [self.url] + self.extra_urls
routes = [] routes: list[AbstractRoute] = []
for method in ("get", "post", "delete", "put", "patch", "head", "options"): for method in ("get", "post", "delete", "put", "patch", "head", "options"):
handler = getattr(self, method, None) handler = getattr(self, method, None)
@ -101,7 +103,9 @@ class HomeAssistantView:
app["allow_cors"](route) app["allow_cors"](route)
def request_handler_factory(view: HomeAssistantView, handler: Callable) -> Callable: def request_handler_factory(
view: HomeAssistantView, handler: Callable
) -> Callable[[web.Request], Awaitable[web.StreamResponse]]:
"""Wrap the handler classes.""" """Wrap the handler classes."""
assert asyncio.iscoroutinefunction(handler) or is_callback( assert asyncio.iscoroutinefunction(handler) or is_callback(
handler handler

View File

@ -23,7 +23,7 @@ class HomeAssistantTCPSite(web.BaseSite):
__slots__ = ("_host", "_port", "_reuse_address", "_reuse_port", "_hosturl") __slots__ = ("_host", "_port", "_reuse_address", "_reuse_port", "_hosturl")
def __init__( # noqa: D107 def __init__(
self, self,
runner: web.BaseRunner, runner: web.BaseRunner,
host: None | str | list[str], host: None | str | list[str],
@ -35,6 +35,7 @@ class HomeAssistantTCPSite(web.BaseSite):
reuse_address: bool | None = None, reuse_address: bool | None = None,
reuse_port: bool | None = None, reuse_port: bool | None = None,
) -> None: ) -> None:
"""Initialize HomeAssistantTCPSite."""
super().__init__( super().__init__(
runner, runner,
shutdown_timeout=shutdown_timeout, shutdown_timeout=shutdown_timeout,
@ -47,12 +48,14 @@ class HomeAssistantTCPSite(web.BaseSite):
self._reuse_port = reuse_port self._reuse_port = reuse_port
@property @property
def name(self) -> str: # noqa: D102 def name(self) -> str:
"""Return server URL."""
scheme = "https" if self._ssl_context else "http" scheme = "https" if self._ssl_context else "http"
host = self._host[0] if isinstance(self._host, list) else "0.0.0.0" host = self._host[0] if isinstance(self._host, list) else "0.0.0.0"
return str(URL.build(scheme=scheme, host=host, port=self._port)) return str(URL.build(scheme=scheme, host=host, port=self._port))
async def start(self) -> None: # noqa: D102 async def start(self) -> None:
"""Start server."""
await super().start() await super().start()
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
server = self._runner.server server = self._runner.server

View File

@ -593,7 +593,7 @@ SERVICE_TOGGLE_COVER_TILT = "toggle_cover_tilt"
SERVICE_SELECT_OPTION = "select_option" SERVICE_SELECT_OPTION = "select_option"
# #### API / REMOTE #### # #### API / REMOTE ####
SERVER_PORT = 8123 SERVER_PORT: Final = 8123
URL_ROOT = "/" URL_ROOT = "/"
URL_API = "/api/" URL_API = "/api/"

View File

@ -334,7 +334,7 @@ def async_register_implementation(
if isinstance(implementation, LocalOAuth2Implementation) and not hass.data.get( if isinstance(implementation, LocalOAuth2Implementation) and not hass.data.get(
DATA_VIEW_REGISTERED, False DATA_VIEW_REGISTERED, False
): ):
hass.http.register_view(OAuth2AuthorizeCallbackView()) # type: ignore hass.http.register_view(OAuth2AuthorizeCallbackView())
hass.data[DATA_VIEW_REGISTERED] = True hass.data[DATA_VIEW_REGISTERED] = True
implementations = hass.data.setdefault(DATA_IMPLEMENTATIONS, {}) implementations = hass.data.setdefault(DATA_IMPLEMENTATIONS, {})