mirror of
https://github.com/home-assistant/core.git
synced 2025-07-16 09:47:13 +00:00
Fix CORS duplicate registration (#15670)
This commit is contained in:
parent
68f03dcc67
commit
169c8d793a
@ -27,30 +27,36 @@ def setup_cors(app, origins):
|
|||||||
) for host in origins
|
) for host in origins
|
||||||
})
|
})
|
||||||
|
|
||||||
def allow_cors(route, methods):
|
cors_added = set()
|
||||||
|
|
||||||
|
def _allow_cors(route, config=None):
|
||||||
"""Allow cors on a route."""
|
"""Allow cors on a route."""
|
||||||
cors.add(route, {
|
if hasattr(route, 'resource'):
|
||||||
|
path = route.resource
|
||||||
|
else:
|
||||||
|
path = route
|
||||||
|
|
||||||
|
path = path.canonical
|
||||||
|
|
||||||
|
if path in cors_added:
|
||||||
|
return
|
||||||
|
|
||||||
|
cors.add(route, config)
|
||||||
|
cors_added.add(path)
|
||||||
|
|
||||||
|
app['allow_cors'] = lambda route: _allow_cors(route, {
|
||||||
'*': aiohttp_cors.ResourceOptions(
|
'*': aiohttp_cors.ResourceOptions(
|
||||||
allow_headers=ALLOWED_CORS_HEADERS,
|
allow_headers=ALLOWED_CORS_HEADERS,
|
||||||
allow_methods=methods,
|
allow_methods='*',
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
app['allow_cors'] = allow_cors
|
|
||||||
|
|
||||||
if not origins:
|
if not origins:
|
||||||
return
|
return
|
||||||
|
|
||||||
async def cors_startup(app):
|
async def cors_startup(app):
|
||||||
"""Initialize cors when app starts up."""
|
"""Initialize cors when app starts up."""
|
||||||
cors_added = set()
|
|
||||||
|
|
||||||
for route in list(app.router.routes()):
|
for route in list(app.router.routes()):
|
||||||
if hasattr(route, 'resource'):
|
_allow_cors(route)
|
||||||
route = route.resource
|
|
||||||
if route in cors_added:
|
|
||||||
continue
|
|
||||||
cors.add(route)
|
|
||||||
cors_added.add(route)
|
|
||||||
|
|
||||||
app.on_startup.append(cors_startup)
|
app.on_startup.append(cors_startup)
|
||||||
|
@ -69,15 +69,13 @@ class HomeAssistantView:
|
|||||||
handler = request_handler_factory(self, handler)
|
handler = request_handler_factory(self, handler)
|
||||||
|
|
||||||
for url in urls:
|
for url in urls:
|
||||||
routes.append(
|
routes.append(router.add_route(method, url, handler))
|
||||||
(method, router.add_route(method, url, handler))
|
|
||||||
)
|
|
||||||
|
|
||||||
if not self.cors_allowed:
|
if not self.cors_allowed:
|
||||||
return
|
return
|
||||||
|
|
||||||
for method, route in routes:
|
for route in routes:
|
||||||
app['allow_cors'](route, [method.upper()])
|
app['allow_cors'](route)
|
||||||
|
|
||||||
|
|
||||||
def request_handler_factory(view, handler):
|
def request_handler_factory(view, handler):
|
||||||
|
@ -14,6 +14,7 @@ import pytest
|
|||||||
from homeassistant.const import HTTP_HEADER_HA_AUTH
|
from homeassistant.const import HTTP_HEADER_HA_AUTH
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
from homeassistant.components.http.cors import setup_cors
|
from homeassistant.components.http.cors import setup_cors
|
||||||
|
from homeassistant.components.http.view import HomeAssistantView
|
||||||
|
|
||||||
|
|
||||||
TRUSTED_ORIGIN = 'https://home-assistant.io'
|
TRUSTED_ORIGIN = 'https://home-assistant.io'
|
||||||
@ -96,3 +97,34 @@ async def test_cors_preflight_allowed(client):
|
|||||||
assert req.headers[ACCESS_CONTROL_ALLOW_ORIGIN] == TRUSTED_ORIGIN
|
assert req.headers[ACCESS_CONTROL_ALLOW_ORIGIN] == TRUSTED_ORIGIN
|
||||||
assert req.headers[ACCESS_CONTROL_ALLOW_HEADERS] == \
|
assert req.headers[ACCESS_CONTROL_ALLOW_HEADERS] == \
|
||||||
HTTP_HEADER_HA_AUTH.upper()
|
HTTP_HEADER_HA_AUTH.upper()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_cors_middleware_with_cors_allowed_view(hass):
|
||||||
|
"""Test that we can configure cors and have a cors_allowed view."""
|
||||||
|
class MyView(HomeAssistantView):
|
||||||
|
"""Test view that allows CORS."""
|
||||||
|
|
||||||
|
requires_auth = False
|
||||||
|
cors_allowed = True
|
||||||
|
|
||||||
|
def __init__(self, url, name):
|
||||||
|
"""Initialize test view."""
|
||||||
|
self.url = url
|
||||||
|
self.name = name
|
||||||
|
|
||||||
|
async def get(self, request):
|
||||||
|
"""Test response."""
|
||||||
|
return "test"
|
||||||
|
|
||||||
|
assert await async_setup_component(hass, 'http', {
|
||||||
|
'http': {
|
||||||
|
'cors_allowed_origins': ['http://home-assistant.io']
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
hass.http.register_view(MyView('/api/test', 'api:test'))
|
||||||
|
hass.http.register_view(MyView('/api/test', 'api:test2'))
|
||||||
|
hass.http.register_view(MyView('/api/test2', 'api:test'))
|
||||||
|
|
||||||
|
hass.http.app._on_startup.freeze()
|
||||||
|
await hass.http.app.startup()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user