* Fix CORS

* rename

* Update view.py
This commit is contained in:
Bram Kragten 2021-11-09 18:30:51 +01:00 committed by GitHub
parent 7e81c6a591
commit 28c07f5c43
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 25 additions and 25 deletions

View File

@ -298,21 +298,24 @@ class HomeAssistantHTTP:
# Should be instance of aiohttp.web_exceptions._HTTPMove. # Should be instance of aiohttp.web_exceptions._HTTPMove.
raise redirect_exc(redirect_to) # type: ignore[arg-type,misc] raise redirect_exc(redirect_to) # type: ignore[arg-type,misc]
self.app.router.add_route("GET", url, redirect) self.app["allow_configured_cors"](
self.app.router.add_route("GET", url, redirect)
)
def register_static_path( def register_static_path(
self, url_path: str, path: str, cache_headers: bool = True self, url_path: str, path: str, cache_headers: bool = True
) -> web.FileResponse | None: ) -> None:
"""Register a folder or file to serve as a static path.""" """Register a folder or file to serve as a static path."""
if os.path.isdir(path): if os.path.isdir(path):
if cache_headers: if cache_headers:
resource: type[ resource: CachingStaticResource | web.StaticResource = (
CachingStaticResource | web.StaticResource CachingStaticResource(url_path, path)
] = CachingStaticResource )
else: else:
resource = web.StaticResource resource = web.StaticResource(url_path, path)
self.app.router.register_resource(resource(url_path, path)) self.app.router.register_resource(resource)
return None self.app["allow_configured_cors"](resource)
return
async def serve_file(request: web.Request) -> web.FileResponse: async def serve_file(request: web.Request) -> web.FileResponse:
"""Serve file from disk.""" """Serve file from disk."""
@ -320,8 +323,9 @@ class HomeAssistantHTTP:
return web.FileResponse(path, headers=CACHE_HEADERS) return web.FileResponse(path, headers=CACHE_HEADERS)
return web.FileResponse(path) return web.FileResponse(path)
self.app.router.add_route("GET", url_path, serve_file) self.app["allow_configured_cors"](
return None self.app.router.add_route("GET", url_path, serve_file)
)
async def start(self) -> None: async def start(self) -> None:
"""Start the aiohttp server.""" """Start the aiohttp server."""

View File

@ -70,7 +70,7 @@ def setup_cors(app: Application, origins: list[str]) -> None:
cors.add(route, config) cors.add(route, config)
cors_added.add(path_str) cors_added.add(path_str)
app["allow_cors"] = lambda route: _allow_cors( app["allow_all_cors"] = lambda route: _allow_cors(
route, route,
{ {
"*": aiohttp_cors.ResourceOptions( "*": aiohttp_cors.ResourceOptions(
@ -79,12 +79,7 @@ def setup_cors(app: Application, origins: list[str]) -> None:
}, },
) )
if not origins: if origins:
return app["allow_configured_cors"] = _allow_cors
else:
async def cors_startup(app: Application) -> None: app["allow_configured_cors"] = lambda _: None
"""Initialize CORS when app starts up."""
for resource in list(app.router.resources()):
_allow_cors(resource)
app.on_startup.append(cors_startup)

View File

@ -94,11 +94,11 @@ class HomeAssistantView:
for url in urls: for url in urls:
routes.append(router.add_route(method, url, handler)) routes.append(router.add_route(method, url, handler))
if not self.cors_allowed: allow_cors = (
return app["allow_all_cors"] if self.cors_allowed else app["allow_configured_cors"]
)
for route in routes: for route in routes:
app["allow_cors"](route) allow_cors(route)
def request_handler_factory( def request_handler_factory(

View File

@ -52,8 +52,8 @@ async def mock_handler(request):
def client(loop, aiohttp_client): def client(loop, aiohttp_client):
"""Fixture to set up a web.Application.""" """Fixture to set up a web.Application."""
app = web.Application() app = web.Application()
app.router.add_get("/", mock_handler)
setup_cors(app, [TRUSTED_ORIGIN]) setup_cors(app, [TRUSTED_ORIGIN])
app["allow_configured_cors"](app.router.add_get("/", mock_handler))
return loop.run_until_complete(aiohttp_client(app)) return loop.run_until_complete(aiohttp_client(app))

View File

@ -13,6 +13,7 @@ async def get_client(aiohttp_client, validator):
"""Generate a client that hits a view decorated with validator.""" """Generate a client that hits a view decorated with validator."""
app = web.Application() app = web.Application()
app["hass"] = Mock(is_stopping=False) app["hass"] = Mock(is_stopping=False)
app["allow_configured_cors"] = lambda _: None
class TestView(HomeAssistantView): class TestView(HomeAssistantView):
url = "/" url = "/"