mirror of
https://github.com/home-assistant/core.git
synced 2025-07-16 01:37:08 +00:00
Ensure headers middleware handles errors too (#98397)
This commit is contained in:
parent
54223fe06c
commit
85c2216cd7
@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from aiohttp.web import Application, Request, StreamResponse, middleware
|
||||
from aiohttp.web_exceptions import HTTPException
|
||||
|
||||
from homeassistant.core import callback
|
||||
|
||||
@ -12,20 +13,29 @@ from homeassistant.core import callback
|
||||
def setup_headers(app: Application, use_x_frame_options: bool) -> None:
|
||||
"""Create headers middleware for the app."""
|
||||
|
||||
added_headers = {
|
||||
"Referrer-Policy": "no-referrer",
|
||||
"X-Content-Type-Options": "nosniff",
|
||||
"Server": "", # Empty server header, to prevent aiohttp of setting one.
|
||||
}
|
||||
|
||||
if use_x_frame_options:
|
||||
added_headers["X-Frame-Options"] = "SAMEORIGIN"
|
||||
|
||||
@middleware
|
||||
async def headers_middleware(
|
||||
request: Request, handler: Callable[[Request], Awaitable[StreamResponse]]
|
||||
) -> StreamResponse:
|
||||
"""Process request and add headers to the responses."""
|
||||
try:
|
||||
response = await handler(request)
|
||||
response.headers["Referrer-Policy"] = "no-referrer"
|
||||
response.headers["X-Content-Type-Options"] = "nosniff"
|
||||
except HTTPException as err:
|
||||
for key, value in added_headers.items():
|
||||
err.headers[key] = value
|
||||
raise err
|
||||
|
||||
# Set an empty server header, to prevent aiohttp of setting one.
|
||||
response.headers["Server"] = ""
|
||||
|
||||
if use_x_frame_options:
|
||||
response.headers["X-Frame-Options"] = "SAMEORIGIN"
|
||||
for key, value in added_headers.items():
|
||||
response.headers[key] = value
|
||||
|
||||
return response
|
||||
|
||||
|
@ -2,21 +2,28 @@
|
||||
from http import HTTPStatus
|
||||
|
||||
from aiohttp import web
|
||||
from aiohttp.web_exceptions import HTTPUnauthorized
|
||||
|
||||
from homeassistant.components.http.headers import setup_headers
|
||||
|
||||
from tests.typing import ClientSessionGenerator
|
||||
|
||||
|
||||
async def mock_handler(request):
|
||||
async def mock_handler(_: web.Request) -> web.Response:
|
||||
"""Return OK."""
|
||||
return web.Response(text="OK")
|
||||
|
||||
|
||||
async def mock_handler_error(_: web.Request) -> web.Response:
|
||||
"""Return Unauthorized."""
|
||||
raise HTTPUnauthorized(text="Ah ah ah, you didn't say the magic word")
|
||||
|
||||
|
||||
async def test_headers_added(aiohttp_client: ClientSessionGenerator) -> None:
|
||||
"""Test that headers are being added on each request."""
|
||||
app = web.Application()
|
||||
app.router.add_get("/", mock_handler)
|
||||
app.router.add_get("/error", mock_handler_error)
|
||||
|
||||
setup_headers(app, use_x_frame_options=True)
|
||||
|
||||
@ -29,11 +36,20 @@ async def test_headers_added(aiohttp_client: ClientSessionGenerator) -> None:
|
||||
assert resp.headers["X-Content-Type-Options"] == "nosniff"
|
||||
assert resp.headers["X-Frame-Options"] == "SAMEORIGIN"
|
||||
|
||||
resp = await mock_api_client.get("/error")
|
||||
|
||||
assert resp.status == HTTPStatus.UNAUTHORIZED
|
||||
assert resp.headers["Referrer-Policy"] == "no-referrer"
|
||||
assert resp.headers["Server"] == ""
|
||||
assert resp.headers["X-Content-Type-Options"] == "nosniff"
|
||||
assert resp.headers["X-Frame-Options"] == "SAMEORIGIN"
|
||||
|
||||
|
||||
async def test_allow_framing(aiohttp_client: ClientSessionGenerator) -> None:
|
||||
"""Test that we allow framing when disabled."""
|
||||
app = web.Application()
|
||||
app.router.add_get("/", mock_handler)
|
||||
app.router.add_get("/error", mock_handler_error)
|
||||
|
||||
setup_headers(app, use_x_frame_options=False)
|
||||
|
||||
@ -42,3 +58,9 @@ async def test_allow_framing(aiohttp_client: ClientSessionGenerator) -> None:
|
||||
|
||||
assert resp.status == HTTPStatus.OK
|
||||
assert "X-Frame-Options" not in resp.headers
|
||||
|
||||
mock_api_client = await aiohttp_client(app)
|
||||
resp = await mock_api_client.get("/error")
|
||||
|
||||
assert resp.status == HTTPStatus.UNAUTHORIZED
|
||||
assert "X-Frame-Options" not in resp.headers
|
||||
|
Loading…
x
Reference in New Issue
Block a user