diff --git a/homeassistant/components/hassio/http.py b/homeassistant/components/hassio/http.py index 302cc00bb9f..73e5549be9a 100644 --- a/homeassistant/components/hassio/http.py +++ b/homeassistant/components/hassio/http.py @@ -1,16 +1,15 @@ """HTTP Support for Hass.io.""" from __future__ import annotations -import asyncio import logging import os import re import aiohttp from aiohttp import web -from aiohttp.hdrs import CONTENT_LENGTH, CONTENT_TYPE +from aiohttp.client import ClientError, ClientTimeout +from aiohttp.hdrs import CONTENT_TYPE from aiohttp.web_exceptions import HTTPBadGateway -import async_timeout from homeassistant.components.http import KEY_AUTHENTICATED, HomeAssistantView from homeassistant.components.onboarding import async_is_onboarded @@ -20,8 +19,6 @@ from .const import X_HASS_IS_ADMIN, X_HASS_USER_ID, X_HASSIO _LOGGER = logging.getLogger(__name__) -MAX_UPLOAD_SIZE = 1024 * 1024 * 1024 - NO_TIMEOUT = re.compile( r"^(?:" r"|homeassistant/update" @@ -75,48 +72,28 @@ class HassIOView(HomeAssistantView): async def _command_proxy( self, path: str, request: web.Request - ) -> web.Response | web.StreamResponse: + ) -> web.StreamResponse: """Return a client request with proxy origin for Hass.io supervisor. This method is a coroutine. """ - read_timeout = _get_timeout(path) - client_timeout = 10 - data = None headers = _init_header(request) if path in ("snapshots/new/upload", "backups/new/upload"): # We need to reuse the full content type that includes the boundary headers[ "Content-Type" ] = request._stored_content_type # pylint: disable=protected-access - - # Backups are big, so we need to adjust the allowed size - request._client_max_size = ( # pylint: disable=protected-access - MAX_UPLOAD_SIZE - ) - client_timeout = 300 - try: - with async_timeout.timeout(client_timeout): - data = await request.read() - - method = getattr(self._websession, request.method.lower()) - client = await method( - f"http://{self._host}/{path}", - data=data, + # Stream the request to the supervisor + client = await self._websession.request( + method=request.method, + url=f"http://{self._host}/{path}", headers=headers, - timeout=read_timeout, + data=request.content, + timeout=_get_timeout(path), ) - # Simple request - if int(client.headers.get(CONTENT_LENGTH, 0)) < 4194000: - # Return Response - body = await client.read() - return web.Response( - content_type=client.content_type, status=client.status, body=body - ) - - # Stream response + # Stream the supervisor response back response = web.StreamResponse(status=client.status, headers=client.headers) response.content_type = client.content_type @@ -126,12 +103,9 @@ class HassIOView(HomeAssistantView): return response - except aiohttp.ClientError as err: + except ClientError as err: _LOGGER.error("Client error on api %s request %s", path, err) - except asyncio.TimeoutError: - _LOGGER.error("Client timeout error on API request %s", path) - raise HTTPBadGateway() @@ -151,11 +125,11 @@ def _init_header(request: web.Request) -> dict[str, str]: return headers -def _get_timeout(path: str) -> int: +def _get_timeout(path: str) -> ClientTimeout: """Return timeout for a URL path.""" if NO_TIMEOUT.match(path): - return 0 - return 300 + return ClientTimeout(connect=10) + return ClientTimeout(connect=10, total=300) def _need_auth(hass, path: str) -> bool: diff --git a/tests/components/hassio/test_http.py b/tests/components/hassio/test_http.py index fc4bb3e6a0d..881d3cc26ed 100644 --- a/tests/components/hassio/test_http.py +++ b/tests/components/hassio/test_http.py @@ -1,11 +1,13 @@ """The tests for the hassio component.""" -import asyncio -from unittest.mock import patch - +from aiohttp.client import ClientError +from aiohttp.streams import StreamReader +from aiohttp.test_utils import TestClient import pytest from homeassistant.components.hassio.http import _need_auth +from tests.test_util.aiohttp import AiohttpClientMocker + async def test_forward_request(hassio_client, aioclient_mock): """Test fetching normal path.""" @@ -106,16 +108,6 @@ async def test_forward_log_request(hassio_client, aioclient_mock): assert len(aioclient_mock.mock_calls) == 1 -async def test_bad_gateway_when_cannot_find_supervisor(hassio_client): - """Test we get a bad gateway error if we can't find supervisor.""" - with patch( - "homeassistant.components.hassio.http.async_timeout.timeout", - side_effect=asyncio.TimeoutError, - ): - resp = await hassio_client.get("/api/hassio/addons/test/info") - assert resp.status == 502 - - async def test_forwarding_user_info(hassio_client, hass_admin_user, aioclient_mock): """Test that we forward user info correctly.""" aioclient_mock.get("http://127.0.0.1/hello") @@ -171,6 +163,37 @@ async def test_backup_download_headers(hassio_client, aioclient_mock): assert resp.headers["Content-Disposition"] == content_disposition +async def test_supervisor_client_error( + hassio_client: TestClient, aioclient_mock: AiohttpClientMocker +): + """Test any client error from the supervisor returns a 502.""" + # Create a request that throws a ClientError + async def raise_client_error(*args): + raise ClientError() + + aioclient_mock.get( + "http://127.0.0.1/test/raise/error", + side_effect=raise_client_error, + ) + + # Verify it returns bad gateway + resp = await hassio_client.get("/api/hassio/test/raise/error") + assert resp.status == 502 + assert len(aioclient_mock.mock_calls) == 1 + + +async def test_streamed_requests( + hassio_client: TestClient, aioclient_mock: AiohttpClientMocker +): + """Test requests get proxied to the supervisor as a stream.""" + aioclient_mock.get("http://127.0.0.1/test/stream") + await hassio_client.get("/api/hassio/test/stream", data="Test data") + assert len(aioclient_mock.mock_calls) == 1 + + # Verify the request body is passed as a StreamReader + assert isinstance(aioclient_mock.mock_calls[0][2], StreamReader) + + def test_need_auth(hass): """Test if the requested path needs authentication.""" assert not _need_auth(hass, "addons/test/logo")