Ensure headers middleware handles errors too (#98397)

This commit is contained in:
Franck Nijhof 2023-08-14 17:48:11 +02:00 committed by GitHub
parent 54223fe06c
commit 85c2216cd7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 41 additions and 9 deletions

View File

@ -4,6 +4,7 @@ from __future__ import annotations
from collections.abc import Awaitable, Callable
from aiohttp.web import Application, Request, StreamResponse, middleware
from aiohttp.web_exceptions import HTTPException
from homeassistant.core import callback
@ -12,20 +13,29 @@ from homeassistant.core import callback
def setup_headers(app: Application, use_x_frame_options: bool) -> None:
"""Create headers middleware for the app."""
added_headers = {
"Referrer-Policy": "no-referrer",
"X-Content-Type-Options": "nosniff",
"Server": "", # Empty server header, to prevent aiohttp of setting one.
}
if use_x_frame_options:
added_headers["X-Frame-Options"] = "SAMEORIGIN"
@middleware
async def headers_middleware(
request: Request, handler: Callable[[Request], Awaitable[StreamResponse]]
) -> StreamResponse:
"""Process request and add headers to the responses."""
try:
response = await handler(request)
response.headers["Referrer-Policy"] = "no-referrer"
response.headers["X-Content-Type-Options"] = "nosniff"
except HTTPException as err:
for key, value in added_headers.items():
err.headers[key] = value
raise err
# Set an empty server header, to prevent aiohttp of setting one.
response.headers["Server"] = ""
if use_x_frame_options:
response.headers["X-Frame-Options"] = "SAMEORIGIN"
for key, value in added_headers.items():
response.headers[key] = value
return response

View File

@ -2,21 +2,28 @@
from http import HTTPStatus
from aiohttp import web
from aiohttp.web_exceptions import HTTPUnauthorized
from homeassistant.components.http.headers import setup_headers
from tests.typing import ClientSessionGenerator
async def mock_handler(request):
async def mock_handler(_: web.Request) -> web.Response:
"""Return OK."""
return web.Response(text="OK")
async def mock_handler_error(_: web.Request) -> web.Response:
"""Return Unauthorized."""
raise HTTPUnauthorized(text="Ah ah ah, you didn't say the magic word")
async def test_headers_added(aiohttp_client: ClientSessionGenerator) -> None:
"""Test that headers are being added on each request."""
app = web.Application()
app.router.add_get("/", mock_handler)
app.router.add_get("/error", mock_handler_error)
setup_headers(app, use_x_frame_options=True)
@ -29,11 +36,20 @@ async def test_headers_added(aiohttp_client: ClientSessionGenerator) -> None:
assert resp.headers["X-Content-Type-Options"] == "nosniff"
assert resp.headers["X-Frame-Options"] == "SAMEORIGIN"
resp = await mock_api_client.get("/error")
assert resp.status == HTTPStatus.UNAUTHORIZED
assert resp.headers["Referrer-Policy"] == "no-referrer"
assert resp.headers["Server"] == ""
assert resp.headers["X-Content-Type-Options"] == "nosniff"
assert resp.headers["X-Frame-Options"] == "SAMEORIGIN"
async def test_allow_framing(aiohttp_client: ClientSessionGenerator) -> None:
"""Test that we allow framing when disabled."""
app = web.Application()
app.router.add_get("/", mock_handler)
app.router.add_get("/error", mock_handler_error)
setup_headers(app, use_x_frame_options=False)
@ -42,3 +58,9 @@ async def test_allow_framing(aiohttp_client: ClientSessionGenerator) -> None:
assert resp.status == HTTPStatus.OK
assert "X-Frame-Options" not in resp.headers
mock_api_client = await aiohttp_client(app)
resp = await mock_api_client.get("/error")
assert resp.status == HTTPStatus.UNAUTHORIZED
assert "X-Frame-Options" not in resp.headers