mirror of
https://github.com/home-assistant/core.git
synced 2025-07-14 08:47:10 +00:00
Add missing type hints in http component (#50411)
This commit is contained in:
parent
85f758380a
commit
ce15f28642
@ -6,16 +6,18 @@ from ipaddress import ip_network
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import ssl
|
import ssl
|
||||||
from typing import Any, Optional, cast
|
from typing import Any, Final, Optional, TypedDict, cast
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from aiohttp.web_exceptions import HTTPMovedPermanently
|
from aiohttp.typedefs import StrOrURL
|
||||||
|
from aiohttp.web_exceptions import HTTPMovedPermanently, HTTPRedirection
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.const import EVENT_HOMEASSISTANT_STOP, SERVER_PORT
|
from homeassistant.const import EVENT_HOMEASSISTANT_STOP, SERVER_PORT
|
||||||
from homeassistant.core import Event, HomeAssistant
|
from homeassistant.core import Event, HomeAssistant
|
||||||
from homeassistant.helpers import storage
|
from homeassistant.helpers import storage
|
||||||
import homeassistant.helpers.config_validation as cv
|
import homeassistant.helpers.config_validation as cv
|
||||||
|
from homeassistant.helpers.typing import ConfigType
|
||||||
from homeassistant.loader import bind_hass
|
from homeassistant.loader import bind_hass
|
||||||
from homeassistant.setup import async_start_setup, async_when_setup_or_start
|
from homeassistant.setup import async_start_setup, async_when_setup_or_start
|
||||||
import homeassistant.util as hass_util
|
import homeassistant.util as hass_util
|
||||||
@ -29,44 +31,42 @@ from .forwarded import async_setup_forwarded
|
|||||||
from .request_context import setup_request_context
|
from .request_context import setup_request_context
|
||||||
from .security_filter import setup_security_filter
|
from .security_filter import setup_security_filter
|
||||||
from .static import CACHE_HEADERS, CachingStaticResource
|
from .static import CACHE_HEADERS, CachingStaticResource
|
||||||
from .view import HomeAssistantView # noqa: F401
|
from .view import HomeAssistantView
|
||||||
from .web_runner import HomeAssistantTCPSite
|
from .web_runner import HomeAssistantTCPSite
|
||||||
|
|
||||||
# mypy: allow-untyped-defs, no-check-untyped-defs
|
DOMAIN: Final = "http"
|
||||||
|
|
||||||
DOMAIN = "http"
|
CONF_SERVER_HOST: Final = "server_host"
|
||||||
|
CONF_SERVER_PORT: Final = "server_port"
|
||||||
|
CONF_BASE_URL: Final = "base_url"
|
||||||
|
CONF_SSL_CERTIFICATE: Final = "ssl_certificate"
|
||||||
|
CONF_SSL_PEER_CERTIFICATE: Final = "ssl_peer_certificate"
|
||||||
|
CONF_SSL_KEY: Final = "ssl_key"
|
||||||
|
CONF_CORS_ORIGINS: Final = "cors_allowed_origins"
|
||||||
|
CONF_USE_X_FORWARDED_FOR: Final = "use_x_forwarded_for"
|
||||||
|
CONF_TRUSTED_PROXIES: Final = "trusted_proxies"
|
||||||
|
CONF_LOGIN_ATTEMPTS_THRESHOLD: Final = "login_attempts_threshold"
|
||||||
|
CONF_IP_BAN_ENABLED: Final = "ip_ban_enabled"
|
||||||
|
CONF_SSL_PROFILE: Final = "ssl_profile"
|
||||||
|
|
||||||
CONF_SERVER_HOST = "server_host"
|
SSL_MODERN: Final = "modern"
|
||||||
CONF_SERVER_PORT = "server_port"
|
SSL_INTERMEDIATE: Final = "intermediate"
|
||||||
CONF_BASE_URL = "base_url"
|
|
||||||
CONF_SSL_CERTIFICATE = "ssl_certificate"
|
|
||||||
CONF_SSL_PEER_CERTIFICATE = "ssl_peer_certificate"
|
|
||||||
CONF_SSL_KEY = "ssl_key"
|
|
||||||
CONF_CORS_ORIGINS = "cors_allowed_origins"
|
|
||||||
CONF_USE_X_FORWARDED_FOR = "use_x_forwarded_for"
|
|
||||||
CONF_TRUSTED_PROXIES = "trusted_proxies"
|
|
||||||
CONF_LOGIN_ATTEMPTS_THRESHOLD = "login_attempts_threshold"
|
|
||||||
CONF_IP_BAN_ENABLED = "ip_ban_enabled"
|
|
||||||
CONF_SSL_PROFILE = "ssl_profile"
|
|
||||||
|
|
||||||
SSL_MODERN = "modern"
|
_LOGGER: Final = logging.getLogger(__name__)
|
||||||
SSL_INTERMEDIATE = "intermediate"
|
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
DEFAULT_DEVELOPMENT: Final = "0"
|
||||||
|
|
||||||
DEFAULT_DEVELOPMENT = "0"
|
|
||||||
# Cast to be able to load custom cards.
|
# Cast to be able to load custom cards.
|
||||||
# My to be able to check url and version info.
|
# My to be able to check url and version info.
|
||||||
DEFAULT_CORS = ["https://cast.home-assistant.io"]
|
DEFAULT_CORS: Final[list[str]] = ["https://cast.home-assistant.io"]
|
||||||
NO_LOGIN_ATTEMPT_THRESHOLD = -1
|
NO_LOGIN_ATTEMPT_THRESHOLD: Final = -1
|
||||||
|
|
||||||
MAX_CLIENT_SIZE: int = 1024 ** 2 * 16
|
MAX_CLIENT_SIZE: Final = 1024 ** 2 * 16
|
||||||
|
|
||||||
STORAGE_KEY = DOMAIN
|
STORAGE_KEY: Final = DOMAIN
|
||||||
STORAGE_VERSION = 1
|
STORAGE_VERSION: Final = 1
|
||||||
SAVE_DELAY = 180
|
SAVE_DELAY: Final = 180
|
||||||
|
|
||||||
HTTP_SCHEMA = vol.All(
|
HTTP_SCHEMA: Final = vol.All(
|
||||||
cv.deprecated(CONF_BASE_URL),
|
cv.deprecated(CONF_BASE_URL),
|
||||||
vol.Schema(
|
vol.Schema(
|
||||||
{
|
{
|
||||||
@ -96,7 +96,24 @@ HTTP_SCHEMA = vol.All(
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
CONFIG_SCHEMA = vol.Schema({DOMAIN: HTTP_SCHEMA}, extra=vol.ALLOW_EXTRA)
|
CONFIG_SCHEMA: Final = vol.Schema({DOMAIN: HTTP_SCHEMA}, extra=vol.ALLOW_EXTRA)
|
||||||
|
|
||||||
|
|
||||||
|
class ConfData(TypedDict, total=False):
|
||||||
|
"""Typed dict for config data."""
|
||||||
|
|
||||||
|
server_host: list[str]
|
||||||
|
server_port: int
|
||||||
|
base_url: str
|
||||||
|
ssl_certificate: str
|
||||||
|
ssl_peer_certificate: str
|
||||||
|
ssl_key: str
|
||||||
|
cors_allowed_origins: list[str]
|
||||||
|
use_x_forwarded_for: bool
|
||||||
|
trusted_proxies: list[str]
|
||||||
|
login_attempts_threshold: int
|
||||||
|
ip_ban_enabled: bool
|
||||||
|
ssl_profile: str
|
||||||
|
|
||||||
|
|
||||||
@bind_hass
|
@bind_hass
|
||||||
@ -113,8 +130,8 @@ class ApiConfig:
|
|||||||
self,
|
self,
|
||||||
local_ip: str,
|
local_ip: str,
|
||||||
host: str,
|
host: str,
|
||||||
port: int | None = SERVER_PORT,
|
port: int,
|
||||||
use_ssl: bool = False,
|
use_ssl: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize a new API config object."""
|
"""Initialize a new API config object."""
|
||||||
self.local_ip = local_ip
|
self.local_ip = local_ip
|
||||||
@ -123,12 +140,12 @@ class ApiConfig:
|
|||||||
self.use_ssl = use_ssl
|
self.use_ssl = use_ssl
|
||||||
|
|
||||||
|
|
||||||
async def async_setup(hass, config):
|
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||||
"""Set up the HTTP API and debug interface."""
|
"""Set up the HTTP API and debug interface."""
|
||||||
conf = config.get(DOMAIN)
|
conf: ConfData | None = config.get(DOMAIN)
|
||||||
|
|
||||||
if conf is None:
|
if conf is None:
|
||||||
conf = HTTP_SCHEMA({})
|
conf = cast(ConfData, HTTP_SCHEMA({}))
|
||||||
|
|
||||||
server_host = conf.get(CONF_SERVER_HOST)
|
server_host = conf.get(CONF_SERVER_HOST)
|
||||||
server_port = conf[CONF_SERVER_PORT]
|
server_port = conf[CONF_SERVER_PORT]
|
||||||
@ -137,7 +154,7 @@ async def async_setup(hass, config):
|
|||||||
ssl_key = conf.get(CONF_SSL_KEY)
|
ssl_key = conf.get(CONF_SSL_KEY)
|
||||||
cors_origins = conf[CONF_CORS_ORIGINS]
|
cors_origins = conf[CONF_CORS_ORIGINS]
|
||||||
use_x_forwarded_for = conf.get(CONF_USE_X_FORWARDED_FOR, False)
|
use_x_forwarded_for = conf.get(CONF_USE_X_FORWARDED_FOR, False)
|
||||||
trusted_proxies = conf.get(CONF_TRUSTED_PROXIES, [])
|
trusted_proxies = conf.get(CONF_TRUSTED_PROXIES) or []
|
||||||
is_ban_enabled = conf[CONF_IP_BAN_ENABLED]
|
is_ban_enabled = conf[CONF_IP_BAN_ENABLED]
|
||||||
login_threshold = conf[CONF_LOGIN_ATTEMPTS_THRESHOLD]
|
login_threshold = conf[CONF_LOGIN_ATTEMPTS_THRESHOLD]
|
||||||
ssl_profile = conf[CONF_SSL_PROFILE]
|
ssl_profile = conf[CONF_SSL_PROFILE]
|
||||||
@ -165,6 +182,8 @@ async def async_setup(hass, config):
|
|||||||
"""Start the server."""
|
"""Start the server."""
|
||||||
with async_start_setup(hass, ["http"]):
|
with async_start_setup(hass, ["http"]):
|
||||||
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, stop_server)
|
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, stop_server)
|
||||||
|
# We already checked it's not None.
|
||||||
|
assert conf is not None
|
||||||
await start_http_server_and_save_config(hass, dict(conf), server)
|
await start_http_server_and_save_config(hass, dict(conf), server)
|
||||||
|
|
||||||
async_when_setup_or_start(hass, "frontend", start_server)
|
async_when_setup_or_start(hass, "frontend", start_server)
|
||||||
@ -190,19 +209,19 @@ class HomeAssistantHTTP:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
hass,
|
hass: HomeAssistant,
|
||||||
ssl_certificate,
|
ssl_certificate: str | None,
|
||||||
ssl_peer_certificate,
|
ssl_peer_certificate: str | None,
|
||||||
ssl_key,
|
ssl_key: str | None,
|
||||||
server_host,
|
server_host: list[str] | None,
|
||||||
server_port,
|
server_port: int,
|
||||||
cors_origins,
|
cors_origins: list[str],
|
||||||
use_x_forwarded_for,
|
use_x_forwarded_for: bool,
|
||||||
trusted_proxies,
|
trusted_proxies: list[str],
|
||||||
login_threshold,
|
login_threshold: int,
|
||||||
is_ban_enabled,
|
is_ban_enabled: bool,
|
||||||
ssl_profile,
|
ssl_profile: str,
|
||||||
):
|
) -> None:
|
||||||
"""Initialize the HTTP Home Assistant server."""
|
"""Initialize the HTTP Home Assistant server."""
|
||||||
app = self.app = web.Application(
|
app = self.app = web.Application(
|
||||||
middlewares=[], client_max_size=MAX_CLIENT_SIZE
|
middlewares=[], client_max_size=MAX_CLIENT_SIZE
|
||||||
@ -237,10 +256,10 @@ class HomeAssistantHTTP:
|
|||||||
self.is_ban_enabled = is_ban_enabled
|
self.is_ban_enabled = is_ban_enabled
|
||||||
self.ssl_profile = ssl_profile
|
self.ssl_profile = ssl_profile
|
||||||
self._handler = None
|
self._handler = None
|
||||||
self.runner = None
|
self.runner: web.AppRunner | None = None
|
||||||
self.site = None
|
self.site: HomeAssistantTCPSite | None = None
|
||||||
|
|
||||||
def register_view(self, view):
|
def register_view(self, view: HomeAssistantView) -> None:
|
||||||
"""Register a view with the WSGI server.
|
"""Register a view with the WSGI server.
|
||||||
|
|
||||||
The view argument must be a class that inherits from HomeAssistantView.
|
The view argument must be a class that inherits from HomeAssistantView.
|
||||||
@ -261,7 +280,13 @@ class HomeAssistantHTTP:
|
|||||||
|
|
||||||
view.register(self.app, self.app.router)
|
view.register(self.app, self.app.router)
|
||||||
|
|
||||||
def register_redirect(self, url, redirect_to, *, redirect_exc=HTTPMovedPermanently):
|
def register_redirect(
|
||||||
|
self,
|
||||||
|
url: str,
|
||||||
|
redirect_to: StrOrURL,
|
||||||
|
*,
|
||||||
|
redirect_exc: type[HTTPRedirection] = HTTPMovedPermanently,
|
||||||
|
) -> None:
|
||||||
"""Register a redirect with the server.
|
"""Register a redirect with the server.
|
||||||
|
|
||||||
If given this must be either a string or callable. In case of a
|
If given this must be either a string or callable. In case of a
|
||||||
@ -271,38 +296,39 @@ class HomeAssistantHTTP:
|
|||||||
rule syntax.
|
rule syntax.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def redirect(request):
|
async def redirect(request: web.Request) -> web.StreamResponse:
|
||||||
"""Redirect to location."""
|
"""Redirect to location."""
|
||||||
raise redirect_exc(redirect_to)
|
# Should be instance of aiohttp.web_exceptions._HTTPMove.
|
||||||
|
raise redirect_exc(redirect_to) # type: ignore[arg-type,misc]
|
||||||
|
|
||||||
self.app.router.add_route("GET", url, redirect)
|
self.app.router.add_route("GET", url, redirect)
|
||||||
|
|
||||||
def register_static_path(self, url_path, path, cache_headers=True):
|
def register_static_path(
|
||||||
|
self, url_path: str, path: str, cache_headers: bool = True
|
||||||
|
) -> web.FileResponse | None:
|
||||||
"""Register a folder or file to serve as a static path."""
|
"""Register a folder or file to serve as a static path."""
|
||||||
if os.path.isdir(path):
|
if os.path.isdir(path):
|
||||||
if cache_headers:
|
if cache_headers:
|
||||||
resource = CachingStaticResource
|
resource: type[
|
||||||
|
CachingStaticResource | web.StaticResource
|
||||||
|
] = CachingStaticResource
|
||||||
else:
|
else:
|
||||||
resource = web.StaticResource
|
resource = web.StaticResource
|
||||||
self.app.router.register_resource(resource(url_path, path))
|
self.app.router.register_resource(resource(url_path, path))
|
||||||
return
|
return None
|
||||||
|
|
||||||
if cache_headers:
|
async def serve_file(request: web.Request) -> web.FileResponse:
|
||||||
|
"""Serve file from disk."""
|
||||||
async def serve_file(request):
|
if cache_headers:
|
||||||
"""Serve file from disk."""
|
|
||||||
return web.FileResponse(path, headers=CACHE_HEADERS)
|
return web.FileResponse(path, headers=CACHE_HEADERS)
|
||||||
|
return web.FileResponse(path)
|
||||||
else:
|
|
||||||
|
|
||||||
async def serve_file(request):
|
|
||||||
"""Serve file from disk."""
|
|
||||||
return web.FileResponse(path)
|
|
||||||
|
|
||||||
self.app.router.add_route("GET", url_path, serve_file)
|
self.app.router.add_route("GET", url_path, serve_file)
|
||||||
|
return None
|
||||||
|
|
||||||
async def start(self):
|
async def start(self) -> None:
|
||||||
"""Start the aiohttp server."""
|
"""Start the aiohttp server."""
|
||||||
|
context: ssl.SSLContext | None
|
||||||
if self.ssl_certificate:
|
if self.ssl_certificate:
|
||||||
try:
|
try:
|
||||||
if self.ssl_profile == SSL_INTERMEDIATE:
|
if self.ssl_profile == SSL_INTERMEDIATE:
|
||||||
@ -334,7 +360,7 @@ class HomeAssistantHTTP:
|
|||||||
# This will now raise a RunTimeError.
|
# This will now raise a RunTimeError.
|
||||||
# To work around this we now prevent the router from getting frozen
|
# To work around this we now prevent the router from getting frozen
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
self.app._router.freeze = lambda: None
|
self.app._router.freeze = lambda: None # type: ignore[assignment]
|
||||||
|
|
||||||
self.runner = web.AppRunner(self.app)
|
self.runner = web.AppRunner(self.app)
|
||||||
await self.runner.setup()
|
await self.runner.setup()
|
||||||
@ -351,17 +377,19 @@ class HomeAssistantHTTP:
|
|||||||
|
|
||||||
_LOGGER.info("Now listening on port %d", self.server_port)
|
_LOGGER.info("Now listening on port %d", self.server_port)
|
||||||
|
|
||||||
async def stop(self):
|
async def stop(self) -> None:
|
||||||
"""Stop the aiohttp server."""
|
"""Stop the aiohttp server."""
|
||||||
await self.site.stop()
|
if self.site is not None:
|
||||||
await self.runner.cleanup()
|
await self.site.stop()
|
||||||
|
if self.runner is not None:
|
||||||
|
await self.runner.cleanup()
|
||||||
|
|
||||||
|
|
||||||
async def start_http_server_and_save_config(
|
async def start_http_server_and_save_config(
|
||||||
hass: HomeAssistant, conf: dict, server: HomeAssistantHTTP
|
hass: HomeAssistant, conf: dict, server: HomeAssistantHTTP
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Startup the http server and save the config."""
|
"""Startup the http server and save the config."""
|
||||||
await server.start() # type: ignore
|
await server.start()
|
||||||
|
|
||||||
# If we are set up successful, we store the HTTP settings for safe mode.
|
# If we are set up successful, we store the HTTP settings for safe mode.
|
||||||
store = storage.Store(hass, STORAGE_VERSION, STORAGE_KEY)
|
store = storage.Store(hass, STORAGE_VERSION, STORAGE_KEY)
|
||||||
|
@ -1,28 +1,33 @@
|
|||||||
"""Authentication for HTTP component."""
|
"""Authentication for HTTP component."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
from datetime import timedelta
|
||||||
import logging
|
import logging
|
||||||
import secrets
|
import secrets
|
||||||
|
from typing import Final
|
||||||
from urllib.parse import unquote
|
from urllib.parse import unquote
|
||||||
|
|
||||||
from aiohttp import hdrs
|
from aiohttp import hdrs
|
||||||
from aiohttp.web import middleware
|
from aiohttp.web import Application, Request, StreamResponse, middleware
|
||||||
import jwt
|
import jwt
|
||||||
|
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.util import dt as dt_util
|
from homeassistant.util import dt as dt_util
|
||||||
|
|
||||||
from .const import KEY_AUTHENTICATED, KEY_HASS_REFRESH_TOKEN_ID, KEY_HASS_USER
|
from .const import KEY_AUTHENTICATED, KEY_HASS_REFRESH_TOKEN_ID, KEY_HASS_USER
|
||||||
|
|
||||||
# mypy: allow-untyped-defs, no-check-untyped-defs
|
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
DATA_API_PASSWORD = "api_password"
|
DATA_API_PASSWORD: Final = "api_password"
|
||||||
DATA_SIGN_SECRET = "http.auth.sign_secret"
|
DATA_SIGN_SECRET: Final = "http.auth.sign_secret"
|
||||||
SIGN_QUERY_PARAM = "authSig"
|
SIGN_QUERY_PARAM: Final = "authSig"
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_sign_path(hass, refresh_token_id, path, expiration):
|
def async_sign_path(
|
||||||
|
hass: HomeAssistant, refresh_token_id: str, path: str, expiration: timedelta
|
||||||
|
) -> str:
|
||||||
"""Sign a path for temporary access without auth header."""
|
"""Sign a path for temporary access without auth header."""
|
||||||
secret = hass.data.get(DATA_SIGN_SECRET)
|
secret = hass.data.get(DATA_SIGN_SECRET)
|
||||||
|
|
||||||
@ -44,17 +49,19 @@ def async_sign_path(hass, refresh_token_id, path, expiration):
|
|||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def setup_auth(hass, app):
|
def setup_auth(hass: HomeAssistant, app: Application) -> None:
|
||||||
"""Create auth middleware for the app."""
|
"""Create auth middleware for the app."""
|
||||||
|
|
||||||
async def async_validate_auth_header(request):
|
async def async_validate_auth_header(request: Request) -> bool:
|
||||||
"""
|
"""
|
||||||
Test authorization header against access token.
|
Test authorization header against access token.
|
||||||
|
|
||||||
Basic auth_type is legacy code, should be removed with api_password.
|
Basic auth_type is legacy code, should be removed with api_password.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
auth_type, auth_val = request.headers.get(hdrs.AUTHORIZATION).split(" ", 1)
|
auth_type, auth_val = request.headers.get(hdrs.AUTHORIZATION, "").split(
|
||||||
|
" ", 1
|
||||||
|
)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
# If no space in authorization header
|
# If no space in authorization header
|
||||||
return False
|
return False
|
||||||
@ -71,7 +78,7 @@ def setup_auth(hass, app):
|
|||||||
request[KEY_HASS_REFRESH_TOKEN_ID] = refresh_token.id
|
request[KEY_HASS_REFRESH_TOKEN_ID] = refresh_token.id
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def async_validate_signed_request(request):
|
async def async_validate_signed_request(request: Request) -> bool:
|
||||||
"""Validate a signed request."""
|
"""Validate a signed request."""
|
||||||
secret = hass.data.get(DATA_SIGN_SECRET)
|
secret = hass.data.get(DATA_SIGN_SECRET)
|
||||||
|
|
||||||
@ -103,7 +110,9 @@ def setup_auth(hass, app):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
@middleware
|
@middleware
|
||||||
async def auth_middleware(request, handler):
|
async def auth_middleware(
|
||||||
|
request: Request, handler: Callable[[Request], Awaitable[StreamResponse]]
|
||||||
|
) -> StreamResponse:
|
||||||
"""Authenticate as middleware."""
|
"""Authenticate as middleware."""
|
||||||
authenticated = False
|
authenticated = False
|
||||||
|
|
||||||
|
@ -2,13 +2,15 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from ipaddress import ip_address
|
from ipaddress import ip_address
|
||||||
import logging
|
import logging
|
||||||
from socket import gethostbyaddr, herror
|
from socket import gethostbyaddr, herror
|
||||||
|
from typing import Any, Final
|
||||||
|
|
||||||
from aiohttp.web import middleware
|
from aiohttp.web import Application, Request, StreamResponse, middleware
|
||||||
from aiohttp.web_exceptions import HTTPForbidden, HTTPUnauthorized
|
from aiohttp.web_exceptions import HTTPForbidden, HTTPUnauthorized
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
@ -19,33 +21,33 @@ from homeassistant.exceptions import HomeAssistantError
|
|||||||
import homeassistant.helpers.config_validation as cv
|
import homeassistant.helpers.config_validation as cv
|
||||||
from homeassistant.util import dt as dt_util, yaml
|
from homeassistant.util import dt as dt_util, yaml
|
||||||
|
|
||||||
# mypy: allow-untyped-defs, no-check-untyped-defs
|
from .view import HomeAssistantView
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER: Final = logging.getLogger(__name__)
|
||||||
|
|
||||||
KEY_BANNED_IPS = "ha_banned_ips"
|
KEY_BANNED_IPS: Final = "ha_banned_ips"
|
||||||
KEY_FAILED_LOGIN_ATTEMPTS = "ha_failed_login_attempts"
|
KEY_FAILED_LOGIN_ATTEMPTS: Final = "ha_failed_login_attempts"
|
||||||
KEY_LOGIN_THRESHOLD = "ha_login_threshold"
|
KEY_LOGIN_THRESHOLD: Final = "ha_login_threshold"
|
||||||
|
|
||||||
NOTIFICATION_ID_BAN = "ip-ban"
|
NOTIFICATION_ID_BAN: Final = "ip-ban"
|
||||||
NOTIFICATION_ID_LOGIN = "http-login"
|
NOTIFICATION_ID_LOGIN: Final = "http-login"
|
||||||
|
|
||||||
IP_BANS_FILE = "ip_bans.yaml"
|
IP_BANS_FILE: Final = "ip_bans.yaml"
|
||||||
ATTR_BANNED_AT = "banned_at"
|
ATTR_BANNED_AT: Final = "banned_at"
|
||||||
|
|
||||||
SCHEMA_IP_BAN_ENTRY = vol.Schema(
|
SCHEMA_IP_BAN_ENTRY: Final = vol.Schema(
|
||||||
{vol.Optional("banned_at"): vol.Any(None, cv.datetime)}
|
{vol.Optional("banned_at"): vol.Any(None, cv.datetime)}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def setup_bans(hass, app, login_threshold):
|
def setup_bans(hass: HomeAssistant, app: Application, login_threshold: int) -> None:
|
||||||
"""Create IP Ban middleware for the app."""
|
"""Create IP Ban middleware for the app."""
|
||||||
app.middlewares.append(ban_middleware)
|
app.middlewares.append(ban_middleware)
|
||||||
app[KEY_FAILED_LOGIN_ATTEMPTS] = defaultdict(int)
|
app[KEY_FAILED_LOGIN_ATTEMPTS] = defaultdict(int)
|
||||||
app[KEY_LOGIN_THRESHOLD] = login_threshold
|
app[KEY_LOGIN_THRESHOLD] = login_threshold
|
||||||
|
|
||||||
async def ban_startup(app):
|
async def ban_startup(app: Application) -> None:
|
||||||
"""Initialize bans when app starts up."""
|
"""Initialize bans when app starts up."""
|
||||||
app[KEY_BANNED_IPS] = await async_load_ip_bans_config(
|
app[KEY_BANNED_IPS] = await async_load_ip_bans_config(
|
||||||
hass, hass.config.path(IP_BANS_FILE)
|
hass, hass.config.path(IP_BANS_FILE)
|
||||||
@ -55,7 +57,9 @@ def setup_bans(hass, app, login_threshold):
|
|||||||
|
|
||||||
|
|
||||||
@middleware
|
@middleware
|
||||||
async def ban_middleware(request, handler):
|
async def ban_middleware(
|
||||||
|
request: Request, handler: Callable[[Request], Awaitable[StreamResponse]]
|
||||||
|
) -> StreamResponse:
|
||||||
"""IP Ban middleware."""
|
"""IP Ban middleware."""
|
||||||
if KEY_BANNED_IPS not in request.app:
|
if KEY_BANNED_IPS not in request.app:
|
||||||
_LOGGER.error("IP Ban middleware loaded but banned IPs not loaded")
|
_LOGGER.error("IP Ban middleware loaded but banned IPs not loaded")
|
||||||
@ -77,10 +81,14 @@ async def ban_middleware(request, handler):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
def log_invalid_auth(func):
|
def log_invalid_auth(
|
||||||
|
func: Callable[..., Awaitable[StreamResponse]]
|
||||||
|
) -> Callable[..., Awaitable[StreamResponse]]:
|
||||||
"""Decorate function to handle invalid auth or failed login attempts."""
|
"""Decorate function to handle invalid auth or failed login attempts."""
|
||||||
|
|
||||||
async def handle_req(view, request, *args, **kwargs):
|
async def handle_req(
|
||||||
|
view: HomeAssistantView, request: Request, *args: Any, **kwargs: Any
|
||||||
|
) -> StreamResponse:
|
||||||
"""Try to log failed login attempts if response status >= 400."""
|
"""Try to log failed login attempts if response status >= 400."""
|
||||||
resp = await func(view, request, *args, **kwargs)
|
resp = await func(view, request, *args, **kwargs)
|
||||||
if resp.status >= HTTP_BAD_REQUEST:
|
if resp.status >= HTTP_BAD_REQUEST:
|
||||||
@ -90,7 +98,7 @@ def log_invalid_auth(func):
|
|||||||
return handle_req
|
return handle_req
|
||||||
|
|
||||||
|
|
||||||
async def process_wrong_login(request):
|
async def process_wrong_login(request: Request) -> None:
|
||||||
"""Process a wrong login attempt.
|
"""Process a wrong login attempt.
|
||||||
|
|
||||||
Increase failed login attempts counter for remote IP address.
|
Increase failed login attempts counter for remote IP address.
|
||||||
@ -152,7 +160,7 @@ async def process_wrong_login(request):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def process_success_login(request):
|
async def process_success_login(request: Request) -> None:
|
||||||
"""Process a success login attempt.
|
"""Process a success login attempt.
|
||||||
|
|
||||||
Reset failed login attempts counter for remote IP address.
|
Reset failed login attempts counter for remote IP address.
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
"""HTTP specific constants."""
|
"""HTTP specific constants."""
|
||||||
KEY_AUTHENTICATED = "ha_authenticated"
|
from typing import Final
|
||||||
KEY_HASS = "hass"
|
|
||||||
KEY_HASS_USER = "hass_user"
|
KEY_AUTHENTICATED: Final = "ha_authenticated"
|
||||||
KEY_HASS_REFRESH_TOKEN_ID = "hass_refresh_token_id"
|
KEY_HASS: Final = "hass"
|
||||||
|
KEY_HASS_USER: Final = "hass_user"
|
||||||
|
KEY_HASS_REFRESH_TOKEN_ID: Final = "hass_refresh_token_id"
|
||||||
|
@ -1,24 +1,33 @@
|
|||||||
"""Provide CORS support for the HTTP component."""
|
"""Provide CORS support for the HTTP component."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Final
|
||||||
|
|
||||||
from aiohttp.hdrs import ACCEPT, AUTHORIZATION, CONTENT_TYPE, ORIGIN
|
from aiohttp.hdrs import ACCEPT, AUTHORIZATION, CONTENT_TYPE, ORIGIN
|
||||||
from aiohttp.web_urldispatcher import Resource, ResourceRoute, StaticResource
|
from aiohttp.web import Application
|
||||||
|
from aiohttp.web_urldispatcher import (
|
||||||
|
AbstractResource,
|
||||||
|
AbstractRoute,
|
||||||
|
Resource,
|
||||||
|
ResourceRoute,
|
||||||
|
StaticResource,
|
||||||
|
)
|
||||||
|
|
||||||
from homeassistant.const import HTTP_HEADER_X_REQUESTED_WITH
|
from homeassistant.const import HTTP_HEADER_X_REQUESTED_WITH
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import callback
|
||||||
|
|
||||||
# mypy: allow-untyped-defs, no-check-untyped-defs
|
ALLOWED_CORS_HEADERS: Final[list[str]] = [
|
||||||
|
|
||||||
ALLOWED_CORS_HEADERS = [
|
|
||||||
ORIGIN,
|
ORIGIN,
|
||||||
ACCEPT,
|
ACCEPT,
|
||||||
HTTP_HEADER_X_REQUESTED_WITH,
|
HTTP_HEADER_X_REQUESTED_WITH,
|
||||||
CONTENT_TYPE,
|
CONTENT_TYPE,
|
||||||
AUTHORIZATION,
|
AUTHORIZATION,
|
||||||
]
|
]
|
||||||
VALID_CORS_TYPES = (Resource, ResourceRoute, StaticResource)
|
VALID_CORS_TYPES: Final = (Resource, ResourceRoute, StaticResource)
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def setup_cors(app, origins):
|
def setup_cors(app: Application, origins: list[str]) -> None:
|
||||||
"""Set up CORS."""
|
"""Set up CORS."""
|
||||||
# This import should remain here. That way the HTTP integration can always
|
# This import should remain here. That way the HTTP integration can always
|
||||||
# be imported by other integrations without it's requirements being installed.
|
# be imported by other integrations without it's requirements being installed.
|
||||||
@ -37,9 +46,12 @@ def setup_cors(app, origins):
|
|||||||
|
|
||||||
cors_added = set()
|
cors_added = set()
|
||||||
|
|
||||||
def _allow_cors(route, config=None):
|
def _allow_cors(
|
||||||
|
route: AbstractRoute | AbstractResource,
|
||||||
|
config: dict[str, aiohttp_cors.ResourceOptions] | None = None,
|
||||||
|
) -> None:
|
||||||
"""Allow CORS on a route."""
|
"""Allow CORS on a route."""
|
||||||
if hasattr(route, "resource"):
|
if isinstance(route, AbstractRoute):
|
||||||
path = route.resource
|
path = route.resource
|
||||||
else:
|
else:
|
||||||
path = route
|
path = route
|
||||||
@ -47,16 +59,16 @@ def setup_cors(app, origins):
|
|||||||
if not isinstance(path, VALID_CORS_TYPES):
|
if not isinstance(path, VALID_CORS_TYPES):
|
||||||
return
|
return
|
||||||
|
|
||||||
path = path.canonical
|
path_str = path.canonical
|
||||||
|
|
||||||
if path.startswith("/api/hassio_ingress/"):
|
if path_str.startswith("/api/hassio_ingress/"):
|
||||||
return
|
return
|
||||||
|
|
||||||
if path in cors_added:
|
if path_str in cors_added:
|
||||||
return
|
return
|
||||||
|
|
||||||
cors.add(route, config)
|
cors.add(route, config)
|
||||||
cors_added.add(path)
|
cors_added.add(path_str)
|
||||||
|
|
||||||
app["allow_cors"] = lambda route: _allow_cors(
|
app["allow_cors"] = lambda route: _allow_cors(
|
||||||
route,
|
route,
|
||||||
@ -70,7 +82,7 @@ def setup_cors(app, origins):
|
|||||||
if not origins:
|
if not origins:
|
||||||
return
|
return
|
||||||
|
|
||||||
async def cors_startup(app):
|
async def cors_startup(app: Application) -> None:
|
||||||
"""Initialize CORS when app starts up."""
|
"""Initialize CORS when app starts up."""
|
||||||
for resource in list(app.router.resources()):
|
for resource in list(app.router.resources()):
|
||||||
_allow_cors(resource)
|
_allow_cors(resource)
|
||||||
|
@ -1,19 +1,20 @@
|
|||||||
"""Middleware to handle forwarded data by a reverse proxy."""
|
"""Middleware to handle forwarded data by a reverse proxy."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
from ipaddress import ip_address
|
from ipaddress import ip_address
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from aiohttp.hdrs import X_FORWARDED_FOR, X_FORWARDED_HOST, X_FORWARDED_PROTO
|
from aiohttp.hdrs import X_FORWARDED_FOR, X_FORWARDED_HOST, X_FORWARDED_PROTO
|
||||||
from aiohttp.web import HTTPBadRequest, middleware
|
from aiohttp.web import Application, HTTPBadRequest, Request, StreamResponse, middleware
|
||||||
|
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import callback
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
# mypy: allow-untyped-defs
|
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_setup_forwarded(app, trusted_proxies):
|
def async_setup_forwarded(app: Application, trusted_proxies: list[str]) -> None:
|
||||||
"""Create forwarded middleware for the app.
|
"""Create forwarded middleware for the app.
|
||||||
|
|
||||||
Process IP addresses, proto and host information in the forwarded for headers.
|
Process IP addresses, proto and host information in the forwarded for headers.
|
||||||
@ -60,17 +61,20 @@ def async_setup_forwarded(app, trusted_proxies):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@middleware
|
@middleware
|
||||||
async def forwarded_middleware(request, handler):
|
async def forwarded_middleware(
|
||||||
|
request: Request, handler: Callable[[Request], Awaitable[StreamResponse]]
|
||||||
|
) -> StreamResponse:
|
||||||
"""Process forwarded data by a reverse proxy."""
|
"""Process forwarded data by a reverse proxy."""
|
||||||
overrides = {}
|
overrides: dict[str, str] = {}
|
||||||
|
|
||||||
# Handle X-Forwarded-For
|
# Handle X-Forwarded-For
|
||||||
forwarded_for_headers = request.headers.getall(X_FORWARDED_FOR, [])
|
forwarded_for_headers: list[str] = request.headers.getall(X_FORWARDED_FOR, [])
|
||||||
if not forwarded_for_headers:
|
if not forwarded_for_headers:
|
||||||
# No forwarding headers, continue as normal
|
# No forwarding headers, continue as normal
|
||||||
return await handler(request)
|
return await handler(request)
|
||||||
|
|
||||||
# Ensure the IP of the connected peer is trusted
|
# Ensure the IP of the connected peer is trusted
|
||||||
|
assert request.transport is not None
|
||||||
connected_ip = ip_address(request.transport.get_extra_info("peername")[0])
|
connected_ip = ip_address(request.transport.get_extra_info("peername")[0])
|
||||||
if not any(connected_ip in trusted_proxy for trusted_proxy in trusted_proxies):
|
if not any(connected_ip in trusted_proxy for trusted_proxy in trusted_proxies):
|
||||||
_LOGGER.warning(
|
_LOGGER.warning(
|
||||||
@ -111,7 +115,9 @@ def async_setup_forwarded(app, trusted_proxies):
|
|||||||
overrides["remote"] = str(forwarded_for[-1])
|
overrides["remote"] = str(forwarded_for[-1])
|
||||||
|
|
||||||
# Handle X-Forwarded-Proto
|
# Handle X-Forwarded-Proto
|
||||||
forwarded_proto_headers = request.headers.getall(X_FORWARDED_PROTO, [])
|
forwarded_proto_headers: list[str] = request.headers.getall(
|
||||||
|
X_FORWARDED_PROTO, []
|
||||||
|
)
|
||||||
if forwarded_proto_headers:
|
if forwarded_proto_headers:
|
||||||
if len(forwarded_proto_headers) > 1:
|
if len(forwarded_proto_headers) > 1:
|
||||||
_LOGGER.error(
|
_LOGGER.error(
|
||||||
@ -151,7 +157,7 @@ def async_setup_forwarded(app, trusted_proxies):
|
|||||||
overrides["scheme"] = forwarded_proto[forwarded_for_index]
|
overrides["scheme"] = forwarded_proto[forwarded_for_index]
|
||||||
|
|
||||||
# Handle X-Forwarded-Host
|
# Handle X-Forwarded-Host
|
||||||
forwarded_host_headers = request.headers.getall(X_FORWARDED_HOST, [])
|
forwarded_host_headers: list[str] = request.headers.getall(X_FORWARDED_HOST, [])
|
||||||
if forwarded_host_headers:
|
if forwarded_host_headers:
|
||||||
# Multiple X-Forwarded-Host headers
|
# Multiple X-Forwarded-Host headers
|
||||||
if len(forwarded_host_headers) > 1:
|
if len(forwarded_host_headers) > 1:
|
||||||
@ -168,7 +174,7 @@ def async_setup_forwarded(app, trusted_proxies):
|
|||||||
overrides["host"] = forwarded_host
|
overrides["host"] = forwarded_host
|
||||||
|
|
||||||
# Done, create a new request based on gathered data.
|
# Done, create a new request based on gathered data.
|
||||||
request = request.clone(**overrides)
|
request = request.clone(**overrides) # type: ignore[arg-type]
|
||||||
return await handler(request)
|
return await handler(request)
|
||||||
|
|
||||||
app.middlewares.append(forwarded_middleware)
|
app.middlewares.append(forwarded_middleware)
|
||||||
|
@ -1,18 +1,24 @@
|
|||||||
"""Middleware to set the request context."""
|
"""Middleware to set the request context."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from aiohttp.web import middleware
|
from collections.abc import Awaitable, Callable
|
||||||
|
from contextvars import ContextVar
|
||||||
|
|
||||||
|
from aiohttp.web import Application, Request, StreamResponse, middleware
|
||||||
|
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import callback
|
||||||
|
|
||||||
# mypy: allow-untyped-defs
|
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def setup_request_context(app, context):
|
def setup_request_context(
|
||||||
|
app: Application, context: ContextVar[Request | None]
|
||||||
|
) -> None:
|
||||||
"""Create request context middleware for the app."""
|
"""Create request context middleware for the app."""
|
||||||
|
|
||||||
@middleware
|
@middleware
|
||||||
async def request_context_middleware(request, handler):
|
async def request_context_middleware(
|
||||||
|
request: Request, handler: Callable[[Request], Awaitable[StreamResponse]]
|
||||||
|
) -> StreamResponse:
|
||||||
"""Request context middleware."""
|
"""Request context middleware."""
|
||||||
context.set(request)
|
context.set(request)
|
||||||
return await handler(request)
|
return await handler(request)
|
||||||
|
@ -1,17 +1,19 @@
|
|||||||
"""Middleware to add some basic security filtering to requests."""
|
"""Middleware to add some basic security filtering to requests."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
from typing import Final
|
||||||
|
|
||||||
from aiohttp.web import HTTPBadRequest, middleware
|
from aiohttp.web import Application, HTTPBadRequest, Request, StreamResponse, middleware
|
||||||
|
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import callback
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
# mypy: allow-untyped-defs
|
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
FILTERS = re.compile(
|
FILTERS: Final = re.compile(
|
||||||
r"(?:"
|
r"(?:"
|
||||||
|
|
||||||
# Common exploits
|
# Common exploits
|
||||||
@ -34,12 +36,14 @@ FILTERS = re.compile(
|
|||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def setup_security_filter(app):
|
def setup_security_filter(app: Application) -> None:
|
||||||
"""Create security filter middleware for the app."""
|
"""Create security filter middleware for the app."""
|
||||||
|
|
||||||
@middleware
|
@middleware
|
||||||
async def security_filter_middleware(request, handler):
|
async def security_filter_middleware(
|
||||||
"""Process request and block commonly known exploit attempts."""
|
request: Request, handler: Callable[[Request], Awaitable[StreamResponse]]
|
||||||
|
) -> StreamResponse:
|
||||||
|
"""Process request and tblock commonly known exploit attempts."""
|
||||||
if FILTERS.search(request.path):
|
if FILTERS.search(request.path):
|
||||||
_LOGGER.warning(
|
_LOGGER.warning(
|
||||||
"Filtered a potential harmful request to: %s", request.raw_path
|
"Filtered a potential harmful request to: %s", request.raw_path
|
||||||
|
@ -1,21 +1,25 @@
|
|||||||
"""Static file handling for HTTP component."""
|
"""Static file handling for HTTP component."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Mapping
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Final
|
||||||
|
|
||||||
from aiohttp import hdrs
|
from aiohttp import hdrs
|
||||||
from aiohttp.web import FileResponse
|
from aiohttp.web import FileResponse, Request, StreamResponse
|
||||||
from aiohttp.web_exceptions import HTTPForbidden, HTTPNotFound
|
from aiohttp.web_exceptions import HTTPForbidden, HTTPNotFound
|
||||||
from aiohttp.web_urldispatcher import StaticResource
|
from aiohttp.web_urldispatcher import StaticResource
|
||||||
|
|
||||||
# mypy: allow-untyped-defs
|
CACHE_TIME: Final = 31 * 86400 # = 1 month
|
||||||
|
CACHE_HEADERS: Final[Mapping[str, str]] = {
|
||||||
CACHE_TIME = 31 * 86400 # = 1 month
|
hdrs.CACHE_CONTROL: f"public, max-age={CACHE_TIME}"
|
||||||
CACHE_HEADERS = {hdrs.CACHE_CONTROL: f"public, max-age={CACHE_TIME}"}
|
}
|
||||||
|
|
||||||
|
|
||||||
class CachingStaticResource(StaticResource):
|
class CachingStaticResource(StaticResource):
|
||||||
"""Static Resource handler that will add cache headers."""
|
"""Static Resource handler that will add cache headers."""
|
||||||
|
|
||||||
async def _handle(self, request):
|
async def _handle(self, request: Request) -> StreamResponse:
|
||||||
rel_url = request.match_info["filename"]
|
rel_url = request.match_info["filename"]
|
||||||
try:
|
try:
|
||||||
filename = Path(rel_url)
|
filename = Path(rel_url)
|
||||||
@ -42,7 +46,6 @@ class CachingStaticResource(StaticResource):
|
|||||||
return FileResponse(
|
return FileResponse(
|
||||||
filepath,
|
filepath,
|
||||||
chunk_size=self._chunk_size,
|
chunk_size=self._chunk_size,
|
||||||
# type ignore: https://github.com/aio-libs/aiohttp/pull/3976
|
headers=CACHE_HEADERS,
|
||||||
headers=CACHE_HEADERS, # type: ignore
|
|
||||||
)
|
)
|
||||||
raise HTTPNotFound
|
raise HTTPNotFound
|
||||||
|
@ -2,9 +2,10 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Callable
|
from typing import Any
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from aiohttp.typedefs import LooseHeaders
|
from aiohttp.typedefs import LooseHeaders
|
||||||
@ -13,6 +14,7 @@ from aiohttp.web_exceptions import (
|
|||||||
HTTPInternalServerError,
|
HTTPInternalServerError,
|
||||||
HTTPUnauthorized,
|
HTTPUnauthorized,
|
||||||
)
|
)
|
||||||
|
from aiohttp.web_urldispatcher import AbstractRoute
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant import exceptions
|
from homeassistant import exceptions
|
||||||
@ -81,7 +83,7 @@ class HomeAssistantView:
|
|||||||
"""Register the view with a router."""
|
"""Register the view with a router."""
|
||||||
assert self.url is not None, "No url set for view"
|
assert self.url is not None, "No url set for view"
|
||||||
urls = [self.url] + self.extra_urls
|
urls = [self.url] + self.extra_urls
|
||||||
routes = []
|
routes: list[AbstractRoute] = []
|
||||||
|
|
||||||
for method in ("get", "post", "delete", "put", "patch", "head", "options"):
|
for method in ("get", "post", "delete", "put", "patch", "head", "options"):
|
||||||
handler = getattr(self, method, None)
|
handler = getattr(self, method, None)
|
||||||
@ -101,7 +103,9 @@ class HomeAssistantView:
|
|||||||
app["allow_cors"](route)
|
app["allow_cors"](route)
|
||||||
|
|
||||||
|
|
||||||
def request_handler_factory(view: HomeAssistantView, handler: Callable) -> Callable:
|
def request_handler_factory(
|
||||||
|
view: HomeAssistantView, handler: Callable
|
||||||
|
) -> Callable[[web.Request], Awaitable[web.StreamResponse]]:
|
||||||
"""Wrap the handler classes."""
|
"""Wrap the handler classes."""
|
||||||
assert asyncio.iscoroutinefunction(handler) or is_callback(
|
assert asyncio.iscoroutinefunction(handler) or is_callback(
|
||||||
handler
|
handler
|
||||||
|
@ -23,7 +23,7 @@ class HomeAssistantTCPSite(web.BaseSite):
|
|||||||
|
|
||||||
__slots__ = ("_host", "_port", "_reuse_address", "_reuse_port", "_hosturl")
|
__slots__ = ("_host", "_port", "_reuse_address", "_reuse_port", "_hosturl")
|
||||||
|
|
||||||
def __init__( # noqa: D107
|
def __init__(
|
||||||
self,
|
self,
|
||||||
runner: web.BaseRunner,
|
runner: web.BaseRunner,
|
||||||
host: None | str | list[str],
|
host: None | str | list[str],
|
||||||
@ -35,6 +35,7 @@ class HomeAssistantTCPSite(web.BaseSite):
|
|||||||
reuse_address: bool | None = None,
|
reuse_address: bool | None = None,
|
||||||
reuse_port: bool | None = None,
|
reuse_port: bool | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""Initialize HomeAssistantTCPSite."""
|
||||||
super().__init__(
|
super().__init__(
|
||||||
runner,
|
runner,
|
||||||
shutdown_timeout=shutdown_timeout,
|
shutdown_timeout=shutdown_timeout,
|
||||||
@ -47,12 +48,14 @@ class HomeAssistantTCPSite(web.BaseSite):
|
|||||||
self._reuse_port = reuse_port
|
self._reuse_port = reuse_port
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self) -> str: # noqa: D102
|
def name(self) -> str:
|
||||||
|
"""Return server URL."""
|
||||||
scheme = "https" if self._ssl_context else "http"
|
scheme = "https" if self._ssl_context else "http"
|
||||||
host = self._host[0] if isinstance(self._host, list) else "0.0.0.0"
|
host = self._host[0] if isinstance(self._host, list) else "0.0.0.0"
|
||||||
return str(URL.build(scheme=scheme, host=host, port=self._port))
|
return str(URL.build(scheme=scheme, host=host, port=self._port))
|
||||||
|
|
||||||
async def start(self) -> None: # noqa: D102
|
async def start(self) -> None:
|
||||||
|
"""Start server."""
|
||||||
await super().start()
|
await super().start()
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
server = self._runner.server
|
server = self._runner.server
|
||||||
|
@ -593,7 +593,7 @@ SERVICE_TOGGLE_COVER_TILT = "toggle_cover_tilt"
|
|||||||
SERVICE_SELECT_OPTION = "select_option"
|
SERVICE_SELECT_OPTION = "select_option"
|
||||||
|
|
||||||
# #### API / REMOTE ####
|
# #### API / REMOTE ####
|
||||||
SERVER_PORT = 8123
|
SERVER_PORT: Final = 8123
|
||||||
|
|
||||||
URL_ROOT = "/"
|
URL_ROOT = "/"
|
||||||
URL_API = "/api/"
|
URL_API = "/api/"
|
||||||
|
@ -334,7 +334,7 @@ def async_register_implementation(
|
|||||||
if isinstance(implementation, LocalOAuth2Implementation) and not hass.data.get(
|
if isinstance(implementation, LocalOAuth2Implementation) and not hass.data.get(
|
||||||
DATA_VIEW_REGISTERED, False
|
DATA_VIEW_REGISTERED, False
|
||||||
):
|
):
|
||||||
hass.http.register_view(OAuth2AuthorizeCallbackView()) # type: ignore
|
hass.http.register_view(OAuth2AuthorizeCallbackView())
|
||||||
hass.data[DATA_VIEW_REGISTERED] = True
|
hass.data[DATA_VIEW_REGISTERED] = True
|
||||||
|
|
||||||
implementations = hass.data.setdefault(DATA_IMPLEMENTATIONS, {})
|
implementations = hass.data.setdefault(DATA_IMPLEMENTATIONS, {})
|
||||||
|
Loading…
x
Reference in New Issue
Block a user