Allow uploading large snapshots (#53528)

Co-authored-by: Pascal Vizeli <pascal.vizeli@syshack.ch>
This commit is contained in:
Stephen Beechen 2021-07-28 23:12:59 -06:00 committed by GitHub
parent e14a04df2e
commit cdce14d63d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 50 additions and 53 deletions

View File

@ -1,16 +1,15 @@
"""HTTP Support for Hass.io.""" """HTTP Support for Hass.io."""
from __future__ import annotations from __future__ import annotations
import asyncio
import logging import logging
import os import os
import re import re
import aiohttp import aiohttp
from aiohttp import web from aiohttp import web
from aiohttp.hdrs import CONTENT_LENGTH, CONTENT_TYPE from aiohttp.client import ClientError, ClientTimeout
from aiohttp.hdrs import CONTENT_TYPE
from aiohttp.web_exceptions import HTTPBadGateway from aiohttp.web_exceptions import HTTPBadGateway
import async_timeout
from homeassistant.components.http import KEY_AUTHENTICATED, HomeAssistantView from homeassistant.components.http import KEY_AUTHENTICATED, HomeAssistantView
from homeassistant.components.onboarding import async_is_onboarded from homeassistant.components.onboarding import async_is_onboarded
@ -20,8 +19,6 @@ from .const import X_HASS_IS_ADMIN, X_HASS_USER_ID, X_HASSIO
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
MAX_UPLOAD_SIZE = 1024 * 1024 * 1024
NO_TIMEOUT = re.compile( NO_TIMEOUT = re.compile(
r"^(?:" r"^(?:"
r"|homeassistant/update" r"|homeassistant/update"
@ -75,48 +72,28 @@ class HassIOView(HomeAssistantView):
async def _command_proxy( async def _command_proxy(
self, path: str, request: web.Request self, path: str, request: web.Request
) -> web.Response | web.StreamResponse: ) -> web.StreamResponse:
"""Return a client request with proxy origin for Hass.io supervisor. """Return a client request with proxy origin for Hass.io supervisor.
This method is a coroutine. This method is a coroutine.
""" """
read_timeout = _get_timeout(path)
client_timeout = 10
data = None
headers = _init_header(request) headers = _init_header(request)
if path in ("snapshots/new/upload", "backups/new/upload"): if path in ("snapshots/new/upload", "backups/new/upload"):
# We need to reuse the full content type that includes the boundary # We need to reuse the full content type that includes the boundary
headers[ headers[
"Content-Type" "Content-Type"
] = request._stored_content_type # pylint: disable=protected-access ] = request._stored_content_type # pylint: disable=protected-access
# Backups are big, so we need to adjust the allowed size
request._client_max_size = ( # pylint: disable=protected-access
MAX_UPLOAD_SIZE
)
client_timeout = 300
try: try:
with async_timeout.timeout(client_timeout): # Stream the request to the supervisor
data = await request.read() client = await self._websession.request(
method=request.method,
method = getattr(self._websession, request.method.lower()) url=f"http://{self._host}/{path}",
client = await method(
f"http://{self._host}/{path}",
data=data,
headers=headers, headers=headers,
timeout=read_timeout, data=request.content,
timeout=_get_timeout(path),
) )
# Simple request # Stream the supervisor response back
if int(client.headers.get(CONTENT_LENGTH, 0)) < 4194000:
# Return Response
body = await client.read()
return web.Response(
content_type=client.content_type, status=client.status, body=body
)
# Stream response
response = web.StreamResponse(status=client.status, headers=client.headers) response = web.StreamResponse(status=client.status, headers=client.headers)
response.content_type = client.content_type response.content_type = client.content_type
@ -126,12 +103,9 @@ class HassIOView(HomeAssistantView):
return response return response
except aiohttp.ClientError as err: except ClientError as err:
_LOGGER.error("Client error on api %s request %s", path, err) _LOGGER.error("Client error on api %s request %s", path, err)
except asyncio.TimeoutError:
_LOGGER.error("Client timeout error on API request %s", path)
raise HTTPBadGateway() raise HTTPBadGateway()
@ -151,11 +125,11 @@ def _init_header(request: web.Request) -> dict[str, str]:
return headers return headers
def _get_timeout(path: str) -> int: def _get_timeout(path: str) -> ClientTimeout:
"""Return timeout for a URL path.""" """Return timeout for a URL path."""
if NO_TIMEOUT.match(path): if NO_TIMEOUT.match(path):
return 0 return ClientTimeout(connect=10)
return 300 return ClientTimeout(connect=10, total=300)
def _need_auth(hass, path: str) -> bool: def _need_auth(hass, path: str) -> bool:

View File

@ -1,11 +1,13 @@
"""The tests for the hassio component.""" """The tests for the hassio component."""
import asyncio from aiohttp.client import ClientError
from unittest.mock import patch from aiohttp.streams import StreamReader
from aiohttp.test_utils import TestClient
import pytest import pytest
from homeassistant.components.hassio.http import _need_auth from homeassistant.components.hassio.http import _need_auth
from tests.test_util.aiohttp import AiohttpClientMocker
async def test_forward_request(hassio_client, aioclient_mock): async def test_forward_request(hassio_client, aioclient_mock):
"""Test fetching normal path.""" """Test fetching normal path."""
@ -106,16 +108,6 @@ async def test_forward_log_request(hassio_client, aioclient_mock):
assert len(aioclient_mock.mock_calls) == 1 assert len(aioclient_mock.mock_calls) == 1
async def test_bad_gateway_when_cannot_find_supervisor(hassio_client):
"""Test we get a bad gateway error if we can't find supervisor."""
with patch(
"homeassistant.components.hassio.http.async_timeout.timeout",
side_effect=asyncio.TimeoutError,
):
resp = await hassio_client.get("/api/hassio/addons/test/info")
assert resp.status == 502
async def test_forwarding_user_info(hassio_client, hass_admin_user, aioclient_mock): async def test_forwarding_user_info(hassio_client, hass_admin_user, aioclient_mock):
"""Test that we forward user info correctly.""" """Test that we forward user info correctly."""
aioclient_mock.get("http://127.0.0.1/hello") aioclient_mock.get("http://127.0.0.1/hello")
@ -171,6 +163,37 @@ async def test_backup_download_headers(hassio_client, aioclient_mock):
assert resp.headers["Content-Disposition"] == content_disposition assert resp.headers["Content-Disposition"] == content_disposition
async def test_supervisor_client_error(
hassio_client: TestClient, aioclient_mock: AiohttpClientMocker
):
"""Test any client error from the supervisor returns a 502."""
# Create a request that throws a ClientError
async def raise_client_error(*args):
raise ClientError()
aioclient_mock.get(
"http://127.0.0.1/test/raise/error",
side_effect=raise_client_error,
)
# Verify it returns bad gateway
resp = await hassio_client.get("/api/hassio/test/raise/error")
assert resp.status == 502
assert len(aioclient_mock.mock_calls) == 1
async def test_streamed_requests(
hassio_client: TestClient, aioclient_mock: AiohttpClientMocker
):
"""Test requests get proxied to the supervisor as a stream."""
aioclient_mock.get("http://127.0.0.1/test/stream")
await hassio_client.get("/api/hassio/test/stream", data="Test data")
assert len(aioclient_mock.mock_calls) == 1
# Verify the request body is passed as a StreamReader
assert isinstance(aioclient_mock.mock_calls[0][2], StreamReader)
def test_need_auth(hass): def test_need_auth(hass):
"""Test if the requested path needs authentication.""" """Test if the requested path needs authentication."""
assert not _need_auth(hass, "addons/test/logo") assert not _need_auth(hass, "addons/test/logo")