diff --git a/homeassistant/components/http/cors.py b/homeassistant/components/http/cors.py index 419b62be2c6..880cc47ac0d 100644 --- a/homeassistant/components/http/cors.py +++ b/homeassistant/components/http/cors.py @@ -1,5 +1,5 @@ """Provide CORS support for the HTTP component.""" -from aiohttp.web_urldispatcher import Resource, ResourceRoute +from aiohttp.web_urldispatcher import Resource, ResourceRoute, StaticResource from aiohttp.hdrs import ACCEPT, CONTENT_TYPE, ORIGIN, AUTHORIZATION from homeassistant.const import ( @@ -9,7 +9,7 @@ from homeassistant.core import callback ALLOWED_CORS_HEADERS = [ ORIGIN, ACCEPT, HTTP_HEADER_X_REQUESTED_WITH, CONTENT_TYPE, HTTP_HEADER_HA_AUTH, AUTHORIZATION] -VALID_CORS_TYPES = (Resource, ResourceRoute) +VALID_CORS_TYPES = (Resource, ResourceRoute, StaticResource) @callback @@ -56,7 +56,7 @@ def setup_cors(app, origins): async def cors_startup(app): """Initialize CORS when app starts up.""" - for route in list(app.router.routes()): - _allow_cors(route) + for resource in list(app.router.resources()): + _allow_cors(resource) app.on_startup.append(cors_startup) diff --git a/tests/components/http/test_cors.py b/tests/components/http/test_cors.py index d9fa6c11309..46a2766e541 100644 --- a/tests/components/http/test_cors.py +++ b/tests/components/http/test_cors.py @@ -1,4 +1,5 @@ """Test cors for the HTTP component.""" +from pathlib import Path from unittest.mock import patch from aiohttp import web @@ -152,3 +153,22 @@ async def test_cors_works_with_frontend(hass, hass_client): client = await hass_client() resp = await client.get('/') assert resp.status == 200 + + +async def test_cors_on_static_files(hass, hass_client): + """Test that we enable CORS for static files.""" + assert await async_setup_component(hass, 'frontend', { + 'http': { + 'cors_allowed_origins': ['http://www.example.com'] + } + }) + hass.http.register_static_path('/something', Path(__file__).parent) + + client = await hass_client() + resp = await client.options('/something/__init__.py', headers={ + 'origin': 'http://www.example.com', + ACCESS_CONTROL_REQUEST_METHOD: 'GET', + }) + assert resp.status == 200 + assert resp.headers[ACCESS_CONTROL_ALLOW_ORIGIN] == \ + 'http://www.example.com'