diff --git a/homeassistant/components/auth/indieauth.py b/homeassistant/components/auth/indieauth.py index ef7f8a9b292..48f7ab06ab4 100644 --- a/homeassistant/components/auth/indieauth.py +++ b/homeassistant/components/auth/indieauth.py @@ -1,6 +1,10 @@ """Helpers to resolve client ID/secret.""" +import asyncio +from html.parser import HTMLParser from ipaddress import ip_address, ip_network -from urllib.parse import urlparse +from urllib.parse import urlparse, urljoin + +from aiohttp.client_exceptions import ClientError # IP addresses of loopback interfaces ALLOWED_IPS = ( @@ -16,7 +20,7 @@ ALLOWED_NETWORKS = ( ) -def verify_redirect_uri(client_id, redirect_uri): +async def verify_redirect_uri(hass, client_id, redirect_uri): """Verify that the client and redirect uri match.""" try: client_id_parts = _parse_client_id(client_id) @@ -25,16 +29,75 @@ def verify_redirect_uri(client_id, redirect_uri): redirect_parts = _parse_url(redirect_uri) - # IndieAuth 4.2.2 allows for redirect_uri to be on different domain - # but needs to be specified in link tag when fetching `client_id`. - # This is not implemented. - # Verify redirect url and client url have same scheme and domain. - return ( + is_valid = ( client_id_parts.scheme == redirect_parts.scheme and client_id_parts.netloc == redirect_parts.netloc ) + if is_valid: + return True + + # IndieAuth 4.2.2 allows for redirect_uri to be on different domain + # but needs to be specified in link tag when fetching `client_id`. + redirect_uris = await fetch_redirect_uris(hass, client_id) + return redirect_uri in redirect_uris + + +class LinkTagParser(HTMLParser): + """Parser to find link tags.""" + + def __init__(self, rel): + """Initialize a link tag parser.""" + super().__init__() + self.rel = rel + self.found = [] + + def handle_starttag(self, tag, attrs): + """Handle finding a start tag.""" + if tag != 'link': + return + + attrs = dict(attrs) + + if attrs.get('rel') == self.rel: + self.found.append(attrs.get('href')) + + +async def fetch_redirect_uris(hass, url): + """Find link tag with redirect_uri values. + + IndieAuth 4.2.2 + + The client SHOULD publish one or more tags or Link HTTP headers with + a rel attribute of redirect_uri at the client_id URL. + + We limit to the first 10kB of the page. + + We do not implement extracting redirect uris from headers. + """ + session = hass.helpers.aiohttp_client.async_get_clientsession() + parser = LinkTagParser('redirect_uri') + chunks = 0 + try: + resp = await session.get(url, timeout=5) + + async for data in resp.content.iter_chunked(1024): + parser.feed(data.decode()) + chunks += 1 + + if chunks == 10: + break + + except (asyncio.TimeoutError, ClientError): + pass + + # Authorization endpoints verifying that a redirect_uri is allowed for use + # by a client MUST look for an exact match of the given redirect_uri in the + # request against the list of redirect_uris discovered after resolving any + # relative URLs. + return [urljoin(url, found) for found in parser.found] + def verify_client_id(client_id): """Verify that the client id is valid.""" diff --git a/homeassistant/components/auth/login_flow.py b/homeassistant/components/auth/login_flow.py index bced421d6f9..8b983b6d19f 100644 --- a/homeassistant/components/auth/login_flow.py +++ b/homeassistant/components/auth/login_flow.py @@ -142,8 +142,8 @@ class LoginFlowIndexView(HomeAssistantView): @log_invalid_auth async def post(self, request, data): """Create a new login flow.""" - if not indieauth.verify_redirect_uri(data['client_id'], - data['redirect_uri']): + if not await indieauth.verify_redirect_uri( + request.app['hass'], data['client_id'], data['redirect_uri']): return self.json_message('invalid client id or redirect uri', 400) if isinstance(data['handler'], list): diff --git a/tests/components/auth/test_indieauth.py b/tests/components/auth/test_indieauth.py index 7bd720ddf70..75e61af2e71 100644 --- a/tests/components/auth/test_indieauth.py +++ b/tests/components/auth/test_indieauth.py @@ -1,8 +1,12 @@ """Tests for the client validator.""" -from homeassistant.components.auth import indieauth +from unittest.mock import patch import pytest +from homeassistant.components.auth import indieauth + +from tests.common import mock_coro + def test_client_id_scheme(): """Test we enforce valid scheme.""" @@ -84,27 +88,65 @@ def test_parse_url_path(): assert indieauth._parse_url('http://ex.com').path == '/' -def test_verify_redirect_uri(): +async def test_verify_redirect_uri(): """Test that we verify redirect uri correctly.""" - assert indieauth.verify_redirect_uri( + assert await indieauth.verify_redirect_uri( + None, 'http://ex.com', 'http://ex.com/callback' ) - # Different domain - assert not indieauth.verify_redirect_uri( - 'http://ex.com', - 'http://different.com/callback' - ) + with patch.object(indieauth, 'fetch_redirect_uris', + side_effect=lambda *_: mock_coro([])): + # Different domain + assert not await indieauth.verify_redirect_uri( + None, + 'http://ex.com', + 'http://different.com/callback' + ) - # Different scheme - assert not indieauth.verify_redirect_uri( - 'http://ex.com', - 'https://ex.com/callback' - ) + # Different scheme + assert not await indieauth.verify_redirect_uri( + None, + 'http://ex.com', + 'https://ex.com/callback' + ) - # Different subdomain - assert not indieauth.verify_redirect_uri( - 'https://sub1.ex.com', - 'https://sub2.ex.com/callback' - ) + # Different subdomain + assert not await indieauth.verify_redirect_uri( + None, + 'https://sub1.ex.com', + 'https://sub2.ex.com/callback' + ) + + +async def test_find_link_tag(hass, aioclient_mock): + """Test finding link tag.""" + aioclient_mock.get("http://127.0.0.1:8000", text=""" + + + + + + + + ... + +""") + redirect_uris = await indieauth.fetch_redirect_uris( + hass, "http://127.0.0.1:8000") + + assert redirect_uris == [ + "hass://oauth2_redirect", + "http://127.0.0.1:8000/beer", + ] + + +async def test_find_link_tag_max_size(hass, aioclient_mock): + """Test finding link tag.""" + text = ("0" * 1024 * 10) + '' + aioclient_mock.get("http://127.0.0.1:8000", text=text) + redirect_uris = await indieauth.fetch_redirect_uris( + hass, "http://127.0.0.1:8000") + + assert redirect_uris == [] diff --git a/tests/helpers/test_aiohttp_client.py b/tests/helpers/test_aiohttp_client.py index 28bb31c8482..ccfe1b1aff9 100644 --- a/tests/helpers/test_aiohttp_client.py +++ b/tests/helpers/test_aiohttp_client.py @@ -135,9 +135,8 @@ class TestHelpersAiohttpClient(unittest.TestCase): @asyncio.coroutine def test_async_aiohttp_proxy_stream(aioclient_mock, camera_client): """Test that it fetches the given url.""" - aioclient_mock.get('http://example.com/mjpeg_stream', content=[ - b'Frame1', b'Frame2', b'Frame3' - ]) + aioclient_mock.get('http://example.com/mjpeg_stream', + content=b'Frame1Frame2Frame3') resp = yield from camera_client.get( '/api/camera_proxy_stream/camera.config_test') @@ -145,7 +144,7 @@ def test_async_aiohttp_proxy_stream(aioclient_mock, camera_client): assert resp.status == 200 assert aioclient_mock.call_count == 1 body = yield from resp.text() - assert body == 'Frame3Frame2Frame1' + assert body == 'Frame1Frame2Frame3' @asyncio.coroutine diff --git a/tests/test_util/aiohttp.py b/tests/test_util/aiohttp.py index 0296b8c2fba..813eb84707c 100644 --- a/tests/test_util/aiohttp.py +++ b/tests/test_util/aiohttp.py @@ -7,6 +7,7 @@ from unittest import mock from urllib.parse import parse_qs from aiohttp import ClientSession +from aiohttp.streams import StreamReader from yarl import URL from aiohttp.client_exceptions import ClientResponseError @@ -14,6 +15,15 @@ from aiohttp.client_exceptions import ClientResponseError retype = type(re.compile('')) +def mock_stream(data): + """Mock a stream with data.""" + protocol = mock.Mock(_reading_paused=False) + stream = StreamReader(protocol) + stream.feed_data(data) + stream.feed_eof() + return stream + + class AiohttpClientMocker: """Mock Aiohttp client requests.""" @@ -45,7 +55,7 @@ class AiohttpClientMocker: if not isinstance(url, retype): url = URL(url) if params: - url = url.with_query(params) + url = url.with_query(params) self._mocks.append(AiohttpClientMockResponse( method, url, status, content, cookies, exc, headers)) @@ -130,18 +140,6 @@ class AiohttpClientMockResponse: cookie.value = data self._cookies[name] = cookie - if isinstance(response, list): - self.content = mock.MagicMock() - - @asyncio.coroutine - def read(*argc, **kwargs): - """Read content stream mock.""" - if self.response: - return self.response.pop() - return None - - self.content.read = read - def match_request(self, method, url, params=None): """Test if response answers request.""" if method.lower() != self.method.lower(): @@ -177,6 +175,11 @@ class AiohttpClientMockResponse: """Return dict of cookies.""" return self._cookies + @property + def content(self): + """Return content.""" + return mock_stream(self.response) + @asyncio.coroutine def read(self): """Return mock response."""