diff --git a/homeassistant/components/auth/indieauth.py b/homeassistant/components/auth/indieauth.py index 48f7ab06ab4..bcf73258ffa 100644 --- a/homeassistant/components/auth/indieauth.py +++ b/homeassistant/components/auth/indieauth.py @@ -4,6 +4,7 @@ from html.parser import HTMLParser from ipaddress import ip_address, ip_network from urllib.parse import urlparse, urljoin +import aiohttp from aiohttp.client_exceptions import ClientError # IP addresses of loopback interfaces @@ -76,18 +77,17 @@ async def fetch_redirect_uris(hass, url): We do not implement extracting redirect uris from headers. """ - session = hass.helpers.aiohttp_client.async_get_clientsession() parser = LinkTagParser('redirect_uri') chunks = 0 try: - resp = await session.get(url, timeout=5) + async with aiohttp.ClientSession() as session: + async with session.get(url, timeout=5) as resp: + async for data in resp.content.iter_chunked(1024): + parser.feed(data.decode()) + chunks += 1 - async for data in resp.content.iter_chunked(1024): - parser.feed(data.decode()) - chunks += 1 - - if chunks == 10: - break + if chunks == 10: + break except (asyncio.TimeoutError, ClientError): pass diff --git a/tests/components/auth/test_indieauth.py b/tests/components/auth/test_indieauth.py index 75e61af2e71..d30ead10cb2 100644 --- a/tests/components/auth/test_indieauth.py +++ b/tests/components/auth/test_indieauth.py @@ -1,4 +1,5 @@ """Tests for the client validator.""" +import asyncio from unittest.mock import patch import pytest @@ -6,6 +7,18 @@ import pytest from homeassistant.components.auth import indieauth from tests.common import mock_coro +from tests.test_util.aiohttp import AiohttpClientMocker + + +@pytest.fixture +def mock_session(): + """Mock aiohttp.ClientSession.""" + mocker = AiohttpClientMocker() + + with patch('aiohttp.ClientSession', + side_effect=lambda *args, **kwargs: + mocker.create_session(asyncio.get_event_loop())): + yield mocker def test_client_id_scheme(): @@ -120,9 +133,9 @@ async def test_verify_redirect_uri(): ) -async def test_find_link_tag(hass, aioclient_mock): +async def test_find_link_tag(hass, mock_session): """Test finding link tag.""" - aioclient_mock.get("http://127.0.0.1:8000", text=""" + mock_session.get("http://127.0.0.1:8000", text=""" @@ -142,11 +155,15 @@ async def test_find_link_tag(hass, aioclient_mock): ] -async def test_find_link_tag_max_size(hass, aioclient_mock): +async def test_find_link_tag_max_size(hass, mock_session): """Test finding link tag.""" - text = ("0" * 1024 * 10) + '' - aioclient_mock.get("http://127.0.0.1:8000", text=text) + text = ''.join([ + '', + ("0" * 1024 * 10), + '', + ]) + mock_session.get("http://127.0.0.1:8000", text=text) redirect_uris = await indieauth.fetch_redirect_uris( hass, "http://127.0.0.1:8000") - assert redirect_uris == [] + assert redirect_uris == ["http://127.0.0.1:8000/wine"]