diff --git a/homeassistant/components/http/headers.py b/homeassistant/components/http/headers.py index b53f354b144..20c0a58967b 100644 --- a/homeassistant/components/http/headers.py +++ b/homeassistant/components/http/headers.py @@ -4,6 +4,7 @@ from __future__ import annotations from collections.abc import Awaitable, Callable from aiohttp.web import Application, Request, StreamResponse, middleware +from aiohttp.web_exceptions import HTTPException from homeassistant.core import callback @@ -12,20 +13,29 @@ from homeassistant.core import callback def setup_headers(app: Application, use_x_frame_options: bool) -> None: """Create headers middleware for the app.""" + added_headers = { + "Referrer-Policy": "no-referrer", + "X-Content-Type-Options": "nosniff", + "Server": "", # Empty server header, to prevent aiohttp of setting one. + } + + if use_x_frame_options: + added_headers["X-Frame-Options"] = "SAMEORIGIN" + @middleware async def headers_middleware( request: Request, handler: Callable[[Request], Awaitable[StreamResponse]] ) -> StreamResponse: """Process request and add headers to the responses.""" - response = await handler(request) - response.headers["Referrer-Policy"] = "no-referrer" - response.headers["X-Content-Type-Options"] = "nosniff" + try: + response = await handler(request) + except HTTPException as err: + for key, value in added_headers.items(): + err.headers[key] = value + raise err - # Set an empty server header, to prevent aiohttp of setting one. - response.headers["Server"] = "" - - if use_x_frame_options: - response.headers["X-Frame-Options"] = "SAMEORIGIN" + for key, value in added_headers.items(): + response.headers[key] = value return response diff --git a/tests/components/http/test_headers.py b/tests/components/http/test_headers.py index 6d7dbad68f6..16b897b9f99 100644 --- a/tests/components/http/test_headers.py +++ b/tests/components/http/test_headers.py @@ -2,21 +2,28 @@ from http import HTTPStatus from aiohttp import web +from aiohttp.web_exceptions import HTTPUnauthorized from homeassistant.components.http.headers import setup_headers from tests.typing import ClientSessionGenerator -async def mock_handler(request): +async def mock_handler(_: web.Request) -> web.Response: """Return OK.""" return web.Response(text="OK") +async def mock_handler_error(_: web.Request) -> web.Response: + """Return Unauthorized.""" + raise HTTPUnauthorized(text="Ah ah ah, you didn't say the magic word") + + async def test_headers_added(aiohttp_client: ClientSessionGenerator) -> None: """Test that headers are being added on each request.""" app = web.Application() app.router.add_get("/", mock_handler) + app.router.add_get("/error", mock_handler_error) setup_headers(app, use_x_frame_options=True) @@ -29,11 +36,20 @@ async def test_headers_added(aiohttp_client: ClientSessionGenerator) -> None: assert resp.headers["X-Content-Type-Options"] == "nosniff" assert resp.headers["X-Frame-Options"] == "SAMEORIGIN" + resp = await mock_api_client.get("/error") + + assert resp.status == HTTPStatus.UNAUTHORIZED + assert resp.headers["Referrer-Policy"] == "no-referrer" + assert resp.headers["Server"] == "" + assert resp.headers["X-Content-Type-Options"] == "nosniff" + assert resp.headers["X-Frame-Options"] == "SAMEORIGIN" + async def test_allow_framing(aiohttp_client: ClientSessionGenerator) -> None: """Test that we allow framing when disabled.""" app = web.Application() app.router.add_get("/", mock_handler) + app.router.add_get("/error", mock_handler_error) setup_headers(app, use_x_frame_options=False) @@ -42,3 +58,9 @@ async def test_allow_framing(aiohttp_client: ClientSessionGenerator) -> None: assert resp.status == HTTPStatus.OK assert "X-Frame-Options" not in resp.headers + + mock_api_client = await aiohttp_client(app) + resp = await mock_api_client.get("/error") + + assert resp.status == HTTPStatus.UNAUTHORIZED + assert "X-Frame-Options" not in resp.headers