Allow uploading large snapshots (#53528)

Co-authored-by: Pascal Vizeli <pascal.vizeli@syshack.ch>
This commit is contained in:
Stephen Beechen 2021-07-28 23:12:59 -06:00 committed by GitHub
parent e14a04df2e
commit cdce14d63d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 50 additions and 53 deletions

View File

@ -1,16 +1,15 @@
"""HTTP Support for Hass.io."""
from __future__ import annotations
import asyncio
import logging
import os
import re
import aiohttp
from aiohttp import web
from aiohttp.hdrs import CONTENT_LENGTH, CONTENT_TYPE
from aiohttp.client import ClientError, ClientTimeout
from aiohttp.hdrs import CONTENT_TYPE
from aiohttp.web_exceptions import HTTPBadGateway
import async_timeout
from homeassistant.components.http import KEY_AUTHENTICATED, HomeAssistantView
from homeassistant.components.onboarding import async_is_onboarded
@ -20,8 +19,6 @@ from .const import X_HASS_IS_ADMIN, X_HASS_USER_ID, X_HASSIO
_LOGGER = logging.getLogger(__name__)
MAX_UPLOAD_SIZE = 1024 * 1024 * 1024
NO_TIMEOUT = re.compile(
r"^(?:"
r"|homeassistant/update"
@ -75,48 +72,28 @@ class HassIOView(HomeAssistantView):
async def _command_proxy(
self, path: str, request: web.Request
) -> web.Response | web.StreamResponse:
) -> web.StreamResponse:
"""Return a client request with proxy origin for Hass.io supervisor.
This method is a coroutine.
"""
read_timeout = _get_timeout(path)
client_timeout = 10
data = None
headers = _init_header(request)
if path in ("snapshots/new/upload", "backups/new/upload"):
# We need to reuse the full content type that includes the boundary
headers[
"Content-Type"
] = request._stored_content_type # pylint: disable=protected-access
# Backups are big, so we need to adjust the allowed size
request._client_max_size = ( # pylint: disable=protected-access
MAX_UPLOAD_SIZE
)
client_timeout = 300
try:
with async_timeout.timeout(client_timeout):
data = await request.read()
method = getattr(self._websession, request.method.lower())
client = await method(
f"http://{self._host}/{path}",
data=data,
# Stream the request to the supervisor
client = await self._websession.request(
method=request.method,
url=f"http://{self._host}/{path}",
headers=headers,
timeout=read_timeout,
data=request.content,
timeout=_get_timeout(path),
)
# Simple request
if int(client.headers.get(CONTENT_LENGTH, 0)) < 4194000:
# Return Response
body = await client.read()
return web.Response(
content_type=client.content_type, status=client.status, body=body
)
# Stream response
# Stream the supervisor response back
response = web.StreamResponse(status=client.status, headers=client.headers)
response.content_type = client.content_type
@ -126,12 +103,9 @@ class HassIOView(HomeAssistantView):
return response
except aiohttp.ClientError as err:
except ClientError as err:
_LOGGER.error("Client error on api %s request %s", path, err)
except asyncio.TimeoutError:
_LOGGER.error("Client timeout error on API request %s", path)
raise HTTPBadGateway()
@ -151,11 +125,11 @@ def _init_header(request: web.Request) -> dict[str, str]:
return headers
def _get_timeout(path: str) -> int:
def _get_timeout(path: str) -> ClientTimeout:
"""Return timeout for a URL path."""
if NO_TIMEOUT.match(path):
return 0
return 300
return ClientTimeout(connect=10)
return ClientTimeout(connect=10, total=300)
def _need_auth(hass, path: str) -> bool:

View File

@ -1,11 +1,13 @@
"""The tests for the hassio component."""
import asyncio
from unittest.mock import patch
from aiohttp.client import ClientError
from aiohttp.streams import StreamReader
from aiohttp.test_utils import TestClient
import pytest
from homeassistant.components.hassio.http import _need_auth
from tests.test_util.aiohttp import AiohttpClientMocker
async def test_forward_request(hassio_client, aioclient_mock):
"""Test fetching normal path."""
@ -106,16 +108,6 @@ async def test_forward_log_request(hassio_client, aioclient_mock):
assert len(aioclient_mock.mock_calls) == 1
async def test_bad_gateway_when_cannot_find_supervisor(hassio_client):
"""Test we get a bad gateway error if we can't find supervisor."""
with patch(
"homeassistant.components.hassio.http.async_timeout.timeout",
side_effect=asyncio.TimeoutError,
):
resp = await hassio_client.get("/api/hassio/addons/test/info")
assert resp.status == 502
async def test_forwarding_user_info(hassio_client, hass_admin_user, aioclient_mock):
"""Test that we forward user info correctly."""
aioclient_mock.get("http://127.0.0.1/hello")
@ -171,6 +163,37 @@ async def test_backup_download_headers(hassio_client, aioclient_mock):
assert resp.headers["Content-Disposition"] == content_disposition
async def test_supervisor_client_error(
hassio_client: TestClient, aioclient_mock: AiohttpClientMocker
):
"""Test any client error from the supervisor returns a 502."""
# Create a request that throws a ClientError
async def raise_client_error(*args):
raise ClientError()
aioclient_mock.get(
"http://127.0.0.1/test/raise/error",
side_effect=raise_client_error,
)
# Verify it returns bad gateway
resp = await hassio_client.get("/api/hassio/test/raise/error")
assert resp.status == 502
assert len(aioclient_mock.mock_calls) == 1
async def test_streamed_requests(
hassio_client: TestClient, aioclient_mock: AiohttpClientMocker
):
"""Test requests get proxied to the supervisor as a stream."""
aioclient_mock.get("http://127.0.0.1/test/stream")
await hassio_client.get("/api/hassio/test/stream", data="Test data")
assert len(aioclient_mock.mock_calls) == 1
# Verify the request body is passed as a StreamReader
assert isinstance(aioclient_mock.mock_calls[0][2], StreamReader)
def test_need_auth(hass):
"""Test if the requested path needs authentication."""
assert not _need_auth(hass, "addons/test/logo")