From 9414e6d47213eacf0391a631cab3fda4759afe64 Mon Sep 17 00:00:00 2001 From: Steve Repsher Date: Tue, 6 Aug 2024 10:17:54 -0400 Subject: [PATCH] Adapt static resource handler to aiohttp 3.10 (#123166) --- homeassistant/components/http/static.py | 79 +++++++------------------ tests/components/http/test_static.py | 35 +++-------- 2 files changed, 30 insertions(+), 84 deletions(-) diff --git a/homeassistant/components/http/static.py b/homeassistant/components/http/static.py index a7280fb9b2f..29c5840a4bf 100644 --- a/homeassistant/components/http/static.py +++ b/homeassistant/components/http/static.py @@ -3,81 +3,46 @@ from __future__ import annotations from collections.abc import Mapping -import mimetypes from pathlib import Path from typing import Final -from aiohttp import hdrs +from aiohttp.hdrs import CACHE_CONTROL, CONTENT_TYPE from aiohttp.web import FileResponse, Request, StreamResponse -from aiohttp.web_exceptions import HTTPForbidden, HTTPNotFound +from aiohttp.web_fileresponse import CONTENT_TYPES, FALLBACK_CONTENT_TYPE from aiohttp.web_urldispatcher import StaticResource from lru import LRU -from .const import KEY_HASS - CACHE_TIME: Final = 31 * 86400 # = 1 month CACHE_HEADER = f"public, max-age={CACHE_TIME}" -CACHE_HEADERS: Mapping[str, str] = {hdrs.CACHE_CONTROL: CACHE_HEADER} -PATH_CACHE: LRU[tuple[str, Path], tuple[Path | None, str | None]] = LRU(512) - - -def _get_file_path(rel_url: str, directory: Path) -> Path | None: - """Return the path to file on disk or None.""" - filename = Path(rel_url) - if filename.anchor: - # rel_url is an absolute name like - # /static/\\machine_name\c$ or /static/D:\path - # where the static dir is totally different - raise HTTPForbidden - filepath: Path = directory.joinpath(filename).resolve() - filepath.relative_to(directory) - # on opening a dir, load its contents if allowed - if filepath.is_dir(): - return None - if filepath.is_file(): - return filepath - raise FileNotFoundError +CACHE_HEADERS: Mapping[str, str] = {CACHE_CONTROL: CACHE_HEADER} +RESPONSE_CACHE: LRU[tuple[str, Path], tuple[Path, str]] = LRU(512) class CachingStaticResource(StaticResource): """Static Resource handler that will add cache headers.""" async def _handle(self, request: Request) -> StreamResponse: - """Return requested file from disk as a FileResponse.""" + """Wrap base handler to cache file path resolution and content type guess.""" rel_url = request.match_info["filename"] key = (rel_url, self._directory) - if (filepath_content_type := PATH_CACHE.get(key)) is None: - hass = request.app[KEY_HASS] - try: - filepath = await hass.async_add_executor_job(_get_file_path, *key) - except (ValueError, FileNotFoundError) as error: - # relatively safe - raise HTTPNotFound from error - except HTTPForbidden: - # forbidden - raise - except Exception as error: - # perm error or other kind! - request.app.logger.exception("Unexpected exception") - raise HTTPNotFound from error + response: StreamResponse - content_type: str | None = None - if filepath is not None: - content_type = (mimetypes.guess_type(rel_url))[ - 0 - ] or "application/octet-stream" - PATH_CACHE[key] = (filepath, content_type) + if key in RESPONSE_CACHE: + file_path, content_type = RESPONSE_CACHE[key] + response = FileResponse(file_path, chunk_size=self._chunk_size) + response.headers[CONTENT_TYPE] = content_type else: - filepath, content_type = filepath_content_type - - if filepath and content_type: - return FileResponse( - filepath, - chunk_size=self._chunk_size, - headers={ - hdrs.CACHE_CONTROL: CACHE_HEADER, - hdrs.CONTENT_TYPE: content_type, - }, + response = await super()._handle(request) + if not isinstance(response, FileResponse): + # Must be directory index; ignore caching + return response + file_path = response._path # noqa: SLF001 + response.content_type = ( + CONTENT_TYPES.guess_type(file_path)[0] or FALLBACK_CONTENT_TYPE ) + # Cache actual header after setter construction. + content_type = response.headers[CONTENT_TYPE] + RESPONSE_CACHE[key] = (file_path, content_type) - raise HTTPForbidden if filepath is None else HTTPNotFound + response.headers[CACHE_CONTROL] = CACHE_HEADER + return response diff --git a/tests/components/http/test_static.py b/tests/components/http/test_static.py index 52a5db5daa7..2ac7c6ded93 100644 --- a/tests/components/http/test_static.py +++ b/tests/components/http/test_static.py @@ -4,11 +4,10 @@ from http import HTTPStatus from pathlib import Path from aiohttp.test_utils import TestClient -from aiohttp.web_exceptions import HTTPForbidden import pytest from homeassistant.components.http import StaticPathConfig -from homeassistant.components.http.static import CachingStaticResource, _get_file_path +from homeassistant.components.http.static import CachingStaticResource from homeassistant.const import EVENT_HOMEASSISTANT_START from homeassistant.core import HomeAssistant from homeassistant.helpers.http import KEY_ALLOW_CONFIGURED_CORS @@ -31,37 +30,19 @@ async def mock_http_client(hass: HomeAssistant, aiohttp_client: ClientSessionGen return await aiohttp_client(hass.http.app, server_kwargs={"skip_url_asserts": True}) -@pytest.mark.parametrize( - ("url", "canonical_url"), - [ - ("//a", "//a"), - ("///a", "///a"), - ("/c:\\a\\b", "/c:%5Ca%5Cb"), - ], -) -async def test_static_path_blocks_anchors( - hass: HomeAssistant, - mock_http_client: TestClient, - tmp_path: Path, - url: str, - canonical_url: str, +async def test_static_resource_show_index( + hass: HomeAssistant, mock_http_client: TestClient, tmp_path: Path ) -> None: - """Test static paths block anchors.""" + """Test static resource will return a directory index.""" app = hass.http.app - resource = CachingStaticResource(url, str(tmp_path)) - assert resource.canonical == canonical_url + resource = CachingStaticResource("/", tmp_path, show_index=True) app.router.register_resource(resource) app[KEY_ALLOW_CONFIGURED_CORS](resource) - resp = await mock_http_client.get(canonical_url, allow_redirects=False) - assert resp.status == 403 - - # Tested directly since aiohttp will block it before - # it gets here but we want to make sure if aiohttp ever - # changes we still block it. - with pytest.raises(HTTPForbidden): - _get_file_path(canonical_url, tmp_path) + resp = await mock_http_client.get("/") + assert resp.status == 200 + assert resp.content_type == "text/html" async def test_async_register_static_paths(