Allow CORS requests to token endpoint (#15519)

* Allow CORS requests to token endpoint

* Tests

* Fuck emulated hue

* Clean up

* Only cors existing methods
This commit is contained in:
Paulus Schoutsen 2018-07-19 08:37:00 +02:00 committed by GitHub
parent 22d961de70
commit 2a76a0852f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 59 additions and 23 deletions

View File

@ -241,6 +241,7 @@ class GrantTokenView(HomeAssistantView):
url = '/auth/token' url = '/auth/token'
name = 'api:auth:token' name = 'api:auth:token'
requires_auth = False requires_auth = False
cors_allowed = True
def __init__(self, retrieve_credentials): def __init__(self, retrieve_credentials):
"""Initialize the grant token view.""" """Initialize the grant token view."""

View File

@ -90,12 +90,12 @@ def setup(hass, yaml_config):
handler = None handler = None
server = None server = None
DescriptionXmlView(config).register(app.router) DescriptionXmlView(config).register(app, app.router)
HueUsernameView().register(app.router) HueUsernameView().register(app, app.router)
HueAllLightsStateView(config).register(app.router) HueAllLightsStateView(config).register(app, app.router)
HueOneLightStateView(config).register(app.router) HueOneLightStateView(config).register(app, app.router)
HueOneLightChangeView(config).register(app.router) HueOneLightChangeView(config).register(app, app.router)
HueGroupView(config).register(app.router) HueGroupView(config).register(app, app.router)
upnp_listener = UPNPResponderThread( upnp_listener = UPNPResponderThread(
config.host_ip_addr, config.listen_port, config.host_ip_addr, config.listen_port,

View File

@ -187,8 +187,7 @@ class HomeAssistantHTTP(object):
support_legacy=hass.auth.support_legacy, support_legacy=hass.auth.support_legacy,
api_password=api_password) api_password=api_password)
if cors_origins: setup_cors(app, cors_origins)
setup_cors(app, cors_origins)
app['hass'] = hass app['hass'] = hass
@ -226,7 +225,7 @@ class HomeAssistantHTTP(object):
'{0} missing required attribute "name"'.format(class_name) '{0} missing required attribute "name"'.format(class_name)
) )
view.register(self.app.router) view.register(self.app, self.app.router)
def register_redirect(self, url, redirect_to): def register_redirect(self, url, redirect_to):
"""Register a redirect with the server. """Register a redirect with the server.

View File

@ -27,6 +27,20 @@ def setup_cors(app, origins):
) for host in origins ) for host in origins
}) })
def allow_cors(route, methods):
"""Allow cors on a route."""
cors.add(route, {
'*': aiohttp_cors.ResourceOptions(
allow_headers=ALLOWED_CORS_HEADERS,
allow_methods=methods,
)
})
app['allow_cors'] = allow_cors
if not origins:
return
async def cors_startup(app): async def cors_startup(app):
"""Initialize cors when app starts up.""" """Initialize cors when app starts up."""
cors_added = set() cors_added = set()

View File

@ -26,7 +26,9 @@ class HomeAssistantView(object):
url = None url = None
extra_urls = [] extra_urls = []
requires_auth = True # Views inheriting from this class can override this # Views inheriting from this class can override this
requires_auth = True
cors_allowed = False
# pylint: disable=no-self-use # pylint: disable=no-self-use
def json(self, result, status_code=200, headers=None): def json(self, result, status_code=200, headers=None):
@ -51,10 +53,11 @@ class HomeAssistantView(object):
data['code'] = message_code data['code'] = message_code
return self.json(data, status_code, headers=headers) return self.json(data, status_code, headers=headers)
def register(self, router): def register(self, app, router):
"""Register the view with a router.""" """Register the view with a router."""
assert self.url is not None, 'No url set for view' assert self.url is not None, 'No url set for view'
urls = [self.url] + self.extra_urls urls = [self.url] + self.extra_urls
routes = []
for method in ('get', 'post', 'delete', 'put'): for method in ('get', 'post', 'delete', 'put'):
handler = getattr(self, method, None) handler = getattr(self, method, None)
@ -65,13 +68,15 @@ class HomeAssistantView(object):
handler = request_handler_factory(self, handler) handler = request_handler_factory(self, handler)
for url in urls: for url in urls:
router.add_route(method, url, handler) routes.append(
(method, router.add_route(method, url, handler))
)
# aiohttp_cors does not work with class based views if not self.cors_allowed:
# self.app.router.add_route('*', self.url, self, name=self.name) return
# for url in self.extra_urls: for method, route in routes:
# self.app.router.add_route('*', url, self) app['allow_cors'](route, [method.upper()])
def request_handler_factory(view, handler): def request_handler_factory(view, handler):

View File

@ -93,3 +93,20 @@ async def test_ws_current_user(hass, hass_ws_client, hass_access_token):
assert user_dict['name'] == user.name assert user_dict['name'] == user.name
assert user_dict['id'] == user.id assert user_dict['id'] == user.id
assert user_dict['is_owner'] == user.is_owner assert user_dict['is_owner'] == user.is_owner
async def test_cors_on_token(hass, aiohttp_client):
"""Test logging in with new user and refreshing tokens."""
client = await async_setup_auth(hass, aiohttp_client)
resp = await client.options('/auth/token', headers={
'origin': 'http://example.com',
'Access-Control-Request-Method': 'POST',
})
assert resp.headers['Access-Control-Allow-Origin'] == 'http://example.com'
assert resp.headers['Access-Control-Allow-Methods'] == 'POST'
resp = await client.post('/auth/token', headers={
'origin': 'http://example.com'
})
assert resp.headers['Access-Control-Allow-Origin'] == 'http://example.com'

View File

@ -130,10 +130,10 @@ def hue_client(loop, hass_hue, aiohttp_client):
} }
}) })
HueUsernameView().register(web_app.router) HueUsernameView().register(web_app, web_app.router)
HueAllLightsStateView(config).register(web_app.router) HueAllLightsStateView(config).register(web_app, web_app.router)
HueOneLightStateView(config).register(web_app.router) HueOneLightStateView(config).register(web_app, web_app.router)
HueOneLightChangeView(config).register(web_app.router) HueOneLightChangeView(config).register(web_app, web_app.router)
return loop.run_until_complete(aiohttp_client(web_app)) return loop.run_until_complete(aiohttp_client(web_app))

View File

@ -19,14 +19,14 @@ from homeassistant.components.http.cors import setup_cors
TRUSTED_ORIGIN = 'https://home-assistant.io' TRUSTED_ORIGIN = 'https://home-assistant.io'
async def test_cors_middleware_not_loaded_by_default(hass): async def test_cors_middleware_loaded_by_default(hass):
"""Test accessing to server from banned IP when feature is off.""" """Test accessing to server from banned IP when feature is off."""
with patch('homeassistant.components.http.setup_cors') as mock_setup: with patch('homeassistant.components.http.setup_cors') as mock_setup:
await async_setup_component(hass, 'http', { await async_setup_component(hass, 'http', {
'http': {} 'http': {}
}) })
assert len(mock_setup.mock_calls) == 0 assert len(mock_setup.mock_calls) == 1
async def test_cors_middleware_loaded_from_config(hass): async def test_cors_middleware_loaded_from_config(hass):

View File

@ -23,7 +23,7 @@ async def get_client(aiohttp_client, validator):
"""Test method.""" """Test method."""
return b'' return b''
TestView().register(app.router) TestView().register(app, app.router)
client = await aiohttp_client(app) client = await aiohttp_client(app)
return client return client