mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 21:27:38 +00:00
Use aiohttp.AppKey for http cors keys (#112658)
This commit is contained in:
parent
9555e8764a
commit
2d701d5a7d
@ -33,6 +33,7 @@ from homeassistant.exceptions import HomeAssistantError
|
|||||||
from homeassistant.helpers import storage
|
from homeassistant.helpers import storage
|
||||||
import homeassistant.helpers.config_validation as cv
|
import homeassistant.helpers.config_validation as cv
|
||||||
from homeassistant.helpers.http import (
|
from homeassistant.helpers.http import (
|
||||||
|
KEY_ALLOW_CONFIGRED_CORS,
|
||||||
KEY_AUTHENTICATED, # noqa: F401
|
KEY_AUTHENTICATED, # noqa: F401
|
||||||
KEY_HASS,
|
KEY_HASS,
|
||||||
HomeAssistantView,
|
HomeAssistantView,
|
||||||
@ -389,7 +390,7 @@ class HomeAssistantHTTP:
|
|||||||
# Should be instance of aiohttp.web_exceptions._HTTPMove.
|
# Should be instance of aiohttp.web_exceptions._HTTPMove.
|
||||||
raise redirect_exc(redirect_to) # type: ignore[arg-type,misc]
|
raise redirect_exc(redirect_to) # type: ignore[arg-type,misc]
|
||||||
|
|
||||||
self.app["allow_configured_cors"](
|
self.app[KEY_ALLOW_CONFIGRED_CORS](
|
||||||
self.app.router.add_route("GET", url, redirect)
|
self.app.router.add_route("GET", url, redirect)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -405,7 +406,7 @@ class HomeAssistantHTTP:
|
|||||||
else:
|
else:
|
||||||
resource = web.StaticResource(url_path, path)
|
resource = web.StaticResource(url_path, path)
|
||||||
self.app.router.register_resource(resource)
|
self.app.router.register_resource(resource)
|
||||||
self.app["allow_configured_cors"](resource)
|
self.app[KEY_ALLOW_CONFIGRED_CORS](resource)
|
||||||
return
|
return
|
||||||
|
|
||||||
async def serve_file(request: web.Request) -> web.FileResponse:
|
async def serve_file(request: web.Request) -> web.FileResponse:
|
||||||
@ -414,7 +415,7 @@ class HomeAssistantHTTP:
|
|||||||
return web.FileResponse(path, headers=CACHE_HEADERS)
|
return web.FileResponse(path, headers=CACHE_HEADERS)
|
||||||
return web.FileResponse(path)
|
return web.FileResponse(path)
|
||||||
|
|
||||||
self.app["allow_configured_cors"](
|
self.app[KEY_ALLOW_CONFIGRED_CORS](
|
||||||
self.app.router.add_route("GET", url_path, serve_file)
|
self.app.router.add_route("GET", url_path, serve_file)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
"""Provide CORS support for the HTTP component."""
|
"""Provide CORS support for the HTTP component."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Final
|
from typing import Final, cast
|
||||||
|
|
||||||
from aiohttp.hdrs import ACCEPT, AUTHORIZATION, CONTENT_TYPE, ORIGIN
|
from aiohttp.hdrs import ACCEPT, AUTHORIZATION, CONTENT_TYPE, ORIGIN
|
||||||
from aiohttp.web import Application
|
from aiohttp.web import Application
|
||||||
@ -15,6 +15,11 @@ from aiohttp.web_urldispatcher import (
|
|||||||
|
|
||||||
from homeassistant.const import HTTP_HEADER_X_REQUESTED_WITH
|
from homeassistant.const import HTTP_HEADER_X_REQUESTED_WITH
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import callback
|
||||||
|
from homeassistant.helpers.http import (
|
||||||
|
KEY_ALLOW_ALL_CORS,
|
||||||
|
KEY_ALLOW_CONFIGRED_CORS,
|
||||||
|
AllowCorsType,
|
||||||
|
)
|
||||||
|
|
||||||
ALLOWED_CORS_HEADERS: Final[list[str]] = [
|
ALLOWED_CORS_HEADERS: Final[list[str]] = [
|
||||||
ORIGIN,
|
ORIGIN,
|
||||||
@ -70,7 +75,7 @@ def setup_cors(app: Application, origins: list[str]) -> None:
|
|||||||
cors.add(route, config)
|
cors.add(route, config)
|
||||||
cors_added.add(path_str)
|
cors_added.add(path_str)
|
||||||
|
|
||||||
app["allow_all_cors"] = lambda route: _allow_cors(
|
app[KEY_ALLOW_ALL_CORS] = lambda route: _allow_cors(
|
||||||
route,
|
route,
|
||||||
{
|
{
|
||||||
"*": aiohttp_cors.ResourceOptions(
|
"*": aiohttp_cors.ResourceOptions(
|
||||||
@ -80,6 +85,6 @@ def setup_cors(app: Application, origins: list[str]) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if origins:
|
if origins:
|
||||||
app["allow_configured_cors"] = _allow_cors
|
app[KEY_ALLOW_CONFIGRED_CORS] = cast(AllowCorsType, _allow_cors)
|
||||||
else:
|
else:
|
||||||
app["allow_configured_cors"] = lambda _: None
|
app[KEY_ALLOW_CONFIGRED_CORS] = lambda _: None
|
||||||
|
@ -16,7 +16,7 @@ from aiohttp.web_exceptions import (
|
|||||||
HTTPInternalServerError,
|
HTTPInternalServerError,
|
||||||
HTTPUnauthorized,
|
HTTPUnauthorized,
|
||||||
)
|
)
|
||||||
from aiohttp.web_urldispatcher import AbstractRoute
|
from aiohttp.web_urldispatcher import AbstractResource, AbstractRoute
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant import exceptions
|
from homeassistant import exceptions
|
||||||
@ -29,7 +29,10 @@ from .json import find_paths_unserializable_data, json_bytes, json_dumps
|
|||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
AllowCorsType = Callable[[AbstractRoute | AbstractResource], None]
|
||||||
KEY_AUTHENTICATED: Final = "ha_authenticated"
|
KEY_AUTHENTICATED: Final = "ha_authenticated"
|
||||||
|
KEY_ALLOW_ALL_CORS = AppKey[AllowCorsType]("allow_all_cors")
|
||||||
|
KEY_ALLOW_CONFIGRED_CORS = AppKey[AllowCorsType]("allow_configured_cors")
|
||||||
KEY_HASS: AppKey[HomeAssistant] = AppKey("hass")
|
KEY_HASS: AppKey[HomeAssistant] = AppKey("hass")
|
||||||
|
|
||||||
current_request: ContextVar[Request | None] = ContextVar(
|
current_request: ContextVar[Request | None] = ContextVar(
|
||||||
@ -176,9 +179,9 @@ class HomeAssistantView:
|
|||||||
|
|
||||||
# Use `get` because CORS middleware is not be loaded in emulated_hue
|
# Use `get` because CORS middleware is not be loaded in emulated_hue
|
||||||
if self.cors_allowed:
|
if self.cors_allowed:
|
||||||
allow_cors = app.get("allow_all_cors")
|
allow_cors = app.get(KEY_ALLOW_ALL_CORS)
|
||||||
else:
|
else:
|
||||||
allow_cors = app.get("allow_configured_cors")
|
allow_cors = app.get(KEY_ALLOW_CONFIGRED_CORS)
|
||||||
|
|
||||||
if allow_cors:
|
if allow_cors:
|
||||||
for route in routes:
|
for route in routes:
|
||||||
|
@ -17,6 +17,7 @@ import pytest
|
|||||||
from homeassistant.components.http.cors import setup_cors
|
from homeassistant.components.http.cors import setup_cors
|
||||||
from homeassistant.components.http.view import HomeAssistantView
|
from homeassistant.components.http.view import HomeAssistantView
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.helpers.http import KEY_ALLOW_CONFIGRED_CORS
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
from . import HTTP_HEADER_HA_AUTH
|
from . import HTTP_HEADER_HA_AUTH
|
||||||
@ -56,7 +57,7 @@ def client(event_loop, aiohttp_client):
|
|||||||
"""Fixture to set up a web.Application."""
|
"""Fixture to set up a web.Application."""
|
||||||
app = web.Application()
|
app = web.Application()
|
||||||
setup_cors(app, [TRUSTED_ORIGIN])
|
setup_cors(app, [TRUSTED_ORIGIN])
|
||||||
app["allow_configured_cors"](app.router.add_get("/", mock_handler))
|
app[KEY_ALLOW_CONFIGRED_CORS](app.router.add_get("/", mock_handler))
|
||||||
return event_loop.run_until_complete(aiohttp_client(app))
|
return event_loop.run_until_complete(aiohttp_client(app))
|
||||||
|
|
||||||
|
|
||||||
|
@ -7,6 +7,7 @@ import voluptuous as vol
|
|||||||
|
|
||||||
from homeassistant.components.http import KEY_HASS, HomeAssistantView
|
from homeassistant.components.http import KEY_HASS, HomeAssistantView
|
||||||
from homeassistant.components.http.data_validator import RequestDataValidator
|
from homeassistant.components.http.data_validator import RequestDataValidator
|
||||||
|
from homeassistant.helpers.http import KEY_ALLOW_CONFIGRED_CORS
|
||||||
|
|
||||||
from tests.typing import ClientSessionGenerator
|
from tests.typing import ClientSessionGenerator
|
||||||
|
|
||||||
@ -15,7 +16,7 @@ async def get_client(aiohttp_client, validator):
|
|||||||
"""Generate a client that hits a view decorated with validator."""
|
"""Generate a client that hits a view decorated with validator."""
|
||||||
app = web.Application()
|
app = web.Application()
|
||||||
app[KEY_HASS] = Mock(is_stopping=False)
|
app[KEY_HASS] = Mock(is_stopping=False)
|
||||||
app["allow_configured_cors"] = lambda _: None
|
app[KEY_ALLOW_CONFIGRED_CORS] = lambda _: None
|
||||||
|
|
||||||
class TestView(HomeAssistantView):
|
class TestView(HomeAssistantView):
|
||||||
url = "/"
|
url = "/"
|
||||||
|
@ -9,6 +9,7 @@ import pytest
|
|||||||
|
|
||||||
from homeassistant.components.http.static import CachingStaticResource, _get_file_path
|
from homeassistant.components.http.static import CachingStaticResource, _get_file_path
|
||||||
from homeassistant.core import EVENT_HOMEASSISTANT_START, HomeAssistant
|
from homeassistant.core import EVENT_HOMEASSISTANT_START, HomeAssistant
|
||||||
|
from homeassistant.helpers.http import KEY_ALLOW_CONFIGRED_CORS
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
from tests.typing import ClientSessionGenerator
|
from tests.typing import ClientSessionGenerator
|
||||||
@ -49,7 +50,7 @@ async def test_static_path_blocks_anchors(
|
|||||||
resource = CachingStaticResource(url, str(tmp_path))
|
resource = CachingStaticResource(url, str(tmp_path))
|
||||||
assert resource.canonical == canonical_url
|
assert resource.canonical == canonical_url
|
||||||
app.router.register_resource(resource)
|
app.router.register_resource(resource)
|
||||||
app["allow_configured_cors"](resource)
|
app[KEY_ALLOW_CONFIGRED_CORS](resource)
|
||||||
|
|
||||||
resp = await mock_http_client.get(canonical_url, allow_redirects=False)
|
resp = await mock_http_client.get(canonical_url, allow_redirects=False)
|
||||||
assert resp.status == 403
|
assert resp.status == 403
|
||||||
|
Loading…
x
Reference in New Issue
Block a user