From 28c07f5c437d0de20c7b9328299c9b501d400ebf Mon Sep 17 00:00:00 2001 From: Bram Kragten Date: Tue, 9 Nov 2021 18:30:51 +0100 Subject: [PATCH] Fix CORS (#59360) * Fix CORS * rename * Update view.py --- homeassistant/components/http/__init__.py | 24 ++++++++++++-------- homeassistant/components/http/cors.py | 15 ++++-------- homeassistant/components/http/view.py | 8 +++---- tests/components/http/test_cors.py | 2 +- tests/components/http/test_data_validator.py | 1 + 5 files changed, 25 insertions(+), 25 deletions(-) diff --git a/homeassistant/components/http/__init__.py b/homeassistant/components/http/__init__.py index 19e3437b79c..15766acdd4c 100644 --- a/homeassistant/components/http/__init__.py +++ b/homeassistant/components/http/__init__.py @@ -298,21 +298,24 @@ class HomeAssistantHTTP: # Should be instance of aiohttp.web_exceptions._HTTPMove. raise redirect_exc(redirect_to) # type: ignore[arg-type,misc] - self.app.router.add_route("GET", url, redirect) + self.app["allow_configured_cors"]( + self.app.router.add_route("GET", url, redirect) + ) def register_static_path( self, url_path: str, path: str, cache_headers: bool = True - ) -> web.FileResponse | None: + ) -> None: """Register a folder or file to serve as a static path.""" if os.path.isdir(path): if cache_headers: - resource: type[ - CachingStaticResource | web.StaticResource - ] = CachingStaticResource + resource: CachingStaticResource | web.StaticResource = ( + CachingStaticResource(url_path, path) + ) else: - resource = web.StaticResource - self.app.router.register_resource(resource(url_path, path)) - return None + resource = web.StaticResource(url_path, path) + self.app.router.register_resource(resource) + self.app["allow_configured_cors"](resource) + return async def serve_file(request: web.Request) -> web.FileResponse: """Serve file from disk.""" @@ -320,8 +323,9 @@ class HomeAssistantHTTP: return web.FileResponse(path, headers=CACHE_HEADERS) return web.FileResponse(path) - self.app.router.add_route("GET", url_path, serve_file) - return None + self.app["allow_configured_cors"]( + self.app.router.add_route("GET", url_path, serve_file) + ) async def start(self) -> None: """Start the aiohttp server.""" diff --git a/homeassistant/components/http/cors.py b/homeassistant/components/http/cors.py index d9310c8937f..97a0530b703 100644 --- a/homeassistant/components/http/cors.py +++ b/homeassistant/components/http/cors.py @@ -70,7 +70,7 @@ def setup_cors(app: Application, origins: list[str]) -> None: cors.add(route, config) cors_added.add(path_str) - app["allow_cors"] = lambda route: _allow_cors( + app["allow_all_cors"] = lambda route: _allow_cors( route, { "*": aiohttp_cors.ResourceOptions( @@ -79,12 +79,7 @@ def setup_cors(app: Application, origins: list[str]) -> None: }, ) - if not origins: - return - - async def cors_startup(app: Application) -> None: - """Initialize CORS when app starts up.""" - for resource in list(app.router.resources()): - _allow_cors(resource) - - app.on_startup.append(cors_startup) + if origins: + app["allow_configured_cors"] = _allow_cors + else: + app["allow_configured_cors"] = lambda _: None diff --git a/homeassistant/components/http/view.py b/homeassistant/components/http/view.py index 6123f83563c..aeb610d265e 100644 --- a/homeassistant/components/http/view.py +++ b/homeassistant/components/http/view.py @@ -94,11 +94,11 @@ class HomeAssistantView: for url in urls: routes.append(router.add_route(method, url, handler)) - if not self.cors_allowed: - return - + allow_cors = ( + app["allow_all_cors"] if self.cors_allowed else app["allow_configured_cors"] + ) for route in routes: - app["allow_cors"](route) + allow_cors(route) def request_handler_factory( diff --git a/tests/components/http/test_cors.py b/tests/components/http/test_cors.py index 141627c7763..599df194195 100644 --- a/tests/components/http/test_cors.py +++ b/tests/components/http/test_cors.py @@ -52,8 +52,8 @@ async def mock_handler(request): def client(loop, aiohttp_client): """Fixture to set up a web.Application.""" app = web.Application() - app.router.add_get("/", mock_handler) setup_cors(app, [TRUSTED_ORIGIN]) + app["allow_configured_cors"](app.router.add_get("/", mock_handler)) return loop.run_until_complete(aiohttp_client(app)) diff --git a/tests/components/http/test_data_validator.py b/tests/components/http/test_data_validator.py index 4ff6d3e8c2a..28c43230c43 100644 --- a/tests/components/http/test_data_validator.py +++ b/tests/components/http/test_data_validator.py @@ -13,6 +13,7 @@ async def get_client(aiohttp_client, validator): """Generate a client that hits a view decorated with validator.""" app = web.Application() app["hass"] = Mock(is_stopping=False) + app["allow_configured_cors"] = lambda _: None class TestView(HomeAssistantView): url = "/"