Adapt static resource handler to aiohttp 3.10 (#123166)

This commit is contained in:
Steve Repsher 2024-08-06 10:17:54 -04:00 committed by GitHub
parent 4627a565d3
commit 9414e6d472
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 30 additions and 84 deletions

View File

@ -3,81 +3,46 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Mapping from collections.abc import Mapping
import mimetypes
from pathlib import Path from pathlib import Path
from typing import Final from typing import Final
from aiohttp import hdrs from aiohttp.hdrs import CACHE_CONTROL, CONTENT_TYPE
from aiohttp.web import FileResponse, Request, StreamResponse from aiohttp.web import FileResponse, Request, StreamResponse
from aiohttp.web_exceptions import HTTPForbidden, HTTPNotFound from aiohttp.web_fileresponse import CONTENT_TYPES, FALLBACK_CONTENT_TYPE
from aiohttp.web_urldispatcher import StaticResource from aiohttp.web_urldispatcher import StaticResource
from lru import LRU from lru import LRU
from .const import KEY_HASS
CACHE_TIME: Final = 31 * 86400 # = 1 month CACHE_TIME: Final = 31 * 86400 # = 1 month
CACHE_HEADER = f"public, max-age={CACHE_TIME}" CACHE_HEADER = f"public, max-age={CACHE_TIME}"
CACHE_HEADERS: Mapping[str, str] = {hdrs.CACHE_CONTROL: CACHE_HEADER} CACHE_HEADERS: Mapping[str, str] = {CACHE_CONTROL: CACHE_HEADER}
PATH_CACHE: LRU[tuple[str, Path], tuple[Path | None, str | None]] = LRU(512) RESPONSE_CACHE: LRU[tuple[str, Path], tuple[Path, str]] = LRU(512)
def _get_file_path(rel_url: str, directory: Path) -> Path | None:
"""Return the path to file on disk or None."""
filename = Path(rel_url)
if filename.anchor:
# rel_url is an absolute name like
# /static/\\machine_name\c$ or /static/D:\path
# where the static dir is totally different
raise HTTPForbidden
filepath: Path = directory.joinpath(filename).resolve()
filepath.relative_to(directory)
# on opening a dir, load its contents if allowed
if filepath.is_dir():
return None
if filepath.is_file():
return filepath
raise FileNotFoundError
class CachingStaticResource(StaticResource): class CachingStaticResource(StaticResource):
"""Static Resource handler that will add cache headers.""" """Static Resource handler that will add cache headers."""
async def _handle(self, request: Request) -> StreamResponse: async def _handle(self, request: Request) -> StreamResponse:
"""Return requested file from disk as a FileResponse.""" """Wrap base handler to cache file path resolution and content type guess."""
rel_url = request.match_info["filename"] rel_url = request.match_info["filename"]
key = (rel_url, self._directory) key = (rel_url, self._directory)
if (filepath_content_type := PATH_CACHE.get(key)) is None: response: StreamResponse
hass = request.app[KEY_HASS]
try:
filepath = await hass.async_add_executor_job(_get_file_path, *key)
except (ValueError, FileNotFoundError) as error:
# relatively safe
raise HTTPNotFound from error
except HTTPForbidden:
# forbidden
raise
except Exception as error:
# perm error or other kind!
request.app.logger.exception("Unexpected exception")
raise HTTPNotFound from error
content_type: str | None = None if key in RESPONSE_CACHE:
if filepath is not None: file_path, content_type = RESPONSE_CACHE[key]
content_type = (mimetypes.guess_type(rel_url))[ response = FileResponse(file_path, chunk_size=self._chunk_size)
0 response.headers[CONTENT_TYPE] = content_type
] or "application/octet-stream"
PATH_CACHE[key] = (filepath, content_type)
else: else:
filepath, content_type = filepath_content_type response = await super()._handle(request)
if not isinstance(response, FileResponse):
if filepath and content_type: # Must be directory index; ignore caching
return FileResponse( return response
filepath, file_path = response._path # noqa: SLF001
chunk_size=self._chunk_size, response.content_type = (
headers={ CONTENT_TYPES.guess_type(file_path)[0] or FALLBACK_CONTENT_TYPE
hdrs.CACHE_CONTROL: CACHE_HEADER,
hdrs.CONTENT_TYPE: content_type,
},
) )
# Cache actual header after setter construction.
content_type = response.headers[CONTENT_TYPE]
RESPONSE_CACHE[key] = (file_path, content_type)
raise HTTPForbidden if filepath is None else HTTPNotFound response.headers[CACHE_CONTROL] = CACHE_HEADER
return response

View File

@ -4,11 +4,10 @@ from http import HTTPStatus
from pathlib import Path from pathlib import Path
from aiohttp.test_utils import TestClient from aiohttp.test_utils import TestClient
from aiohttp.web_exceptions import HTTPForbidden
import pytest import pytest
from homeassistant.components.http import StaticPathConfig from homeassistant.components.http import StaticPathConfig
from homeassistant.components.http.static import CachingStaticResource, _get_file_path from homeassistant.components.http.static import CachingStaticResource
from homeassistant.const import EVENT_HOMEASSISTANT_START from homeassistant.const import EVENT_HOMEASSISTANT_START
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers.http import KEY_ALLOW_CONFIGURED_CORS from homeassistant.helpers.http import KEY_ALLOW_CONFIGURED_CORS
@ -31,37 +30,19 @@ async def mock_http_client(hass: HomeAssistant, aiohttp_client: ClientSessionGen
return await aiohttp_client(hass.http.app, server_kwargs={"skip_url_asserts": True}) return await aiohttp_client(hass.http.app, server_kwargs={"skip_url_asserts": True})
@pytest.mark.parametrize( async def test_static_resource_show_index(
("url", "canonical_url"), hass: HomeAssistant, mock_http_client: TestClient, tmp_path: Path
[
("//a", "//a"),
("///a", "///a"),
("/c:\\a\\b", "/c:%5Ca%5Cb"),
],
)
async def test_static_path_blocks_anchors(
hass: HomeAssistant,
mock_http_client: TestClient,
tmp_path: Path,
url: str,
canonical_url: str,
) -> None: ) -> None:
"""Test static paths block anchors.""" """Test static resource will return a directory index."""
app = hass.http.app app = hass.http.app
resource = CachingStaticResource(url, str(tmp_path)) resource = CachingStaticResource("/", tmp_path, show_index=True)
assert resource.canonical == canonical_url
app.router.register_resource(resource) app.router.register_resource(resource)
app[KEY_ALLOW_CONFIGURED_CORS](resource) app[KEY_ALLOW_CONFIGURED_CORS](resource)
resp = await mock_http_client.get(canonical_url, allow_redirects=False) resp = await mock_http_client.get("/")
assert resp.status == 403 assert resp.status == 200
assert resp.content_type == "text/html"
# Tested directly since aiohttp will block it before
# it gets here but we want to make sure if aiohttp ever
# changes we still block it.
with pytest.raises(HTTPForbidden):
_get_file_path(canonical_url, tmp_path)
async def test_async_register_static_paths( async def test_async_register_static_paths(