Add CORS fixes to support OPTIONS preflight requests. (#2773)

* Add CORS fixes to support OPTIONS preflight requests.

* Add CORS tests

* Fix formatting
This commit is contained in:
Robbie Trencheny 2016-08-13 11:49:44 -07:00 committed by Paulus Schoutsen
parent 176a078b3c
commit 7882ce1afd
3 changed files with 78 additions and 5 deletions

View File

@ -453,6 +453,10 @@ class HomeAssistantView(object):
"""Handle request to url.""" """Handle request to url."""
from werkzeug.exceptions import MethodNotAllowed, Unauthorized from werkzeug.exceptions import MethodNotAllowed, Unauthorized
if request.method == "OPTIONS":
# For CORS preflight requests.
return self.options(request)
try: try:
handler = getattr(self, request.method.lower()) handler = getattr(self, request.method.lower())
except AttributeError: except AttributeError:

View File

@ -263,7 +263,8 @@ HTTP_HEADER_ACCESS_CONTROL_ALLOW_ORIGIN = "Access-Control-Allow-Origin"
HTTP_HEADER_ACCESS_CONTROL_ALLOW_HEADERS = "Access-Control-Allow-Headers" HTTP_HEADER_ACCESS_CONTROL_ALLOW_HEADERS = "Access-Control-Allow-Headers"
ALLOWED_CORS_HEADERS = [HTTP_HEADER_ORIGIN, HTTP_HEADER_ACCEPT, ALLOWED_CORS_HEADERS = [HTTP_HEADER_ORIGIN, HTTP_HEADER_ACCEPT,
HTTP_HEADER_X_REQUESTED_WITH, HTTP_HEADER_CONTENT_TYPE] HTTP_HEADER_X_REQUESTED_WITH, HTTP_HEADER_CONTENT_TYPE,
HTTP_HEADER_HA_AUTH]
CONTENT_TYPE_JSON = "application/json" CONTENT_TYPE_JSON = "application/json"
CONTENT_TYPE_MULTIPART = 'multipart/x-mixed-replace; boundary={}' CONTENT_TYPE_MULTIPART = 'multipart/x-mixed-replace; boundary={}'

View File

@ -12,12 +12,15 @@ from tests.common import get_test_instance_port, get_test_home_assistant
API_PASSWORD = "test1234" API_PASSWORD = "test1234"
SERVER_PORT = get_test_instance_port() SERVER_PORT = get_test_instance_port()
HTTP_BASE_URL = "http://127.0.0.1:{}".format(SERVER_PORT) HTTP_BASE = "127.0.0.1:{}".format(SERVER_PORT)
HTTP_BASE_URL = "http://{}".format(HTTP_BASE)
HA_HEADERS = { HA_HEADERS = {
const.HTTP_HEADER_HA_AUTH: API_PASSWORD, const.HTTP_HEADER_HA_AUTH: API_PASSWORD,
const.HTTP_HEADER_CONTENT_TYPE: const.CONTENT_TYPE_JSON, const.HTTP_HEADER_CONTENT_TYPE: const.CONTENT_TYPE_JSON,
} }
CORS_ORIGINS = [HTTP_BASE_URL, HTTP_BASE]
hass = None hass = None
@ -38,7 +41,8 @@ def setUpModule(): # pylint: disable=invalid-name
bootstrap.setup_component( bootstrap.setup_component(
hass, http.DOMAIN, hass, http.DOMAIN,
{http.DOMAIN: {http.CONF_API_PASSWORD: API_PASSWORD, {http.DOMAIN: {http.CONF_API_PASSWORD: API_PASSWORD,
http.CONF_SERVER_PORT: SERVER_PORT}}) http.CONF_SERVER_PORT: SERVER_PORT,
http.CONF_CORS_ORIGINS: CORS_ORIGINS}})
bootstrap.setup_component(hass, 'api') bootstrap.setup_component(hass, 'api')
@ -61,7 +65,7 @@ class TestHttp:
assert req.status_code == 401 assert req.status_code == 401
def test_access_denied_with_wrong_password_in_header(self): def test_access_denied_with_wrong_password_in_header(self):
"""Test ascces with wrong password.""" """Test access with wrong password."""
req = requests.get( req = requests.get(
_url(const.URL_API), _url(const.URL_API),
headers={const.HTTP_HEADER_HA_AUTH: 'wrongpassword'}) headers={const.HTTP_HEADER_HA_AUTH: 'wrongpassword'})
@ -86,7 +90,7 @@ class TestHttp:
assert API_PASSWORD not in logs assert API_PASSWORD not in logs
def test_access_denied_with_wrong_password_in_url(self): def test_access_denied_with_wrong_password_in_url(self):
"""Test ascces with wrong password.""" """Test access with wrong password."""
req = requests.get(_url(const.URL_API), req = requests.get(_url(const.URL_API),
params={'api_password': 'wrongpassword'}) params={'api_password': 'wrongpassword'})
@ -107,3 +111,67 @@ class TestHttp:
# assert const.URL_API in logs # assert const.URL_API in logs
assert API_PASSWORD not in logs assert API_PASSWORD not in logs
def test_cors_allowed_with_password_in_url(self):
"""Test cross origin resource sharing with password in url."""
req = requests.get(_url(const.URL_API),
params={'api_password': API_PASSWORD},
headers={const.HTTP_HEADER_ORIGIN: HTTP_BASE_URL})
allow_origin = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_ORIGIN
allow_headers = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_HEADERS
all_allow_headers = ", ".join(const.ALLOWED_CORS_HEADERS)
assert req.status_code == 200
assert req.headers.get(allow_origin) == HTTP_BASE_URL
assert req.headers.get(allow_headers) == all_allow_headers
def test_cors_allowed_with_password_in_header(self):
"""Test cross origin resource sharing with password in header."""
headers = {
const.HTTP_HEADER_HA_AUTH: API_PASSWORD,
const.HTTP_HEADER_ORIGIN: HTTP_BASE_URL
}
req = requests.get(_url(const.URL_API),
headers=headers)
allow_origin = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_ORIGIN
allow_headers = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_HEADERS
all_allow_headers = ", ".join(const.ALLOWED_CORS_HEADERS)
assert req.status_code == 200
assert req.headers.get(allow_origin) == HTTP_BASE_URL
assert req.headers.get(allow_headers) == all_allow_headers
def test_cors_denied_without_origin_header(self):
"""Test cross origin resource sharing with password in header."""
headers = {
const.HTTP_HEADER_HA_AUTH: API_PASSWORD
}
req = requests.get(_url(const.URL_API),
headers=headers)
allow_origin = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_ORIGIN
allow_headers = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_HEADERS
assert req.status_code == 200
assert allow_origin not in req.headers
assert allow_headers not in req.headers
def test_cors_preflight_allowed(self):
"""Test cross origin resource sharing preflight (OPTIONS) request."""
headers = {
const.HTTP_HEADER_ORIGIN: HTTP_BASE_URL,
'Access-Control-Request-Method': 'GET',
'Access-Control-Request-Headers': 'x-ha-access'
}
req = requests.options(_url(const.URL_API),
headers=headers)
allow_origin = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_ORIGIN
allow_headers = const.HTTP_HEADER_ACCESS_CONTROL_ALLOW_HEADERS
all_allow_headers = ", ".join(const.ALLOWED_CORS_HEADERS)
assert req.status_code == 200
assert req.headers.get(allow_origin) == HTTP_BASE_URL
assert req.headers.get(allow_headers) == all_allow_headers