mirror of
https://github.com/home-assistant/core.git
synced 2025-07-14 00:37:13 +00:00
Allow uploading large snapshots (#53528)
Co-authored-by: Pascal Vizeli <pascal.vizeli@syshack.ch>
This commit is contained in:
parent
e14a04df2e
commit
cdce14d63d
@ -1,16 +1,15 @@
|
|||||||
"""HTTP Support for Hass.io."""
|
"""HTTP Support for Hass.io."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from aiohttp.hdrs import CONTENT_LENGTH, CONTENT_TYPE
|
from aiohttp.client import ClientError, ClientTimeout
|
||||||
|
from aiohttp.hdrs import CONTENT_TYPE
|
||||||
from aiohttp.web_exceptions import HTTPBadGateway
|
from aiohttp.web_exceptions import HTTPBadGateway
|
||||||
import async_timeout
|
|
||||||
|
|
||||||
from homeassistant.components.http import KEY_AUTHENTICATED, HomeAssistantView
|
from homeassistant.components.http import KEY_AUTHENTICATED, HomeAssistantView
|
||||||
from homeassistant.components.onboarding import async_is_onboarded
|
from homeassistant.components.onboarding import async_is_onboarded
|
||||||
@ -20,8 +19,6 @@ from .const import X_HASS_IS_ADMIN, X_HASS_USER_ID, X_HASSIO
|
|||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
MAX_UPLOAD_SIZE = 1024 * 1024 * 1024
|
|
||||||
|
|
||||||
NO_TIMEOUT = re.compile(
|
NO_TIMEOUT = re.compile(
|
||||||
r"^(?:"
|
r"^(?:"
|
||||||
r"|homeassistant/update"
|
r"|homeassistant/update"
|
||||||
@ -75,48 +72,28 @@ class HassIOView(HomeAssistantView):
|
|||||||
|
|
||||||
async def _command_proxy(
|
async def _command_proxy(
|
||||||
self, path: str, request: web.Request
|
self, path: str, request: web.Request
|
||||||
) -> web.Response | web.StreamResponse:
|
) -> web.StreamResponse:
|
||||||
"""Return a client request with proxy origin for Hass.io supervisor.
|
"""Return a client request with proxy origin for Hass.io supervisor.
|
||||||
|
|
||||||
This method is a coroutine.
|
This method is a coroutine.
|
||||||
"""
|
"""
|
||||||
read_timeout = _get_timeout(path)
|
|
||||||
client_timeout = 10
|
|
||||||
data = None
|
|
||||||
headers = _init_header(request)
|
headers = _init_header(request)
|
||||||
if path in ("snapshots/new/upload", "backups/new/upload"):
|
if path in ("snapshots/new/upload", "backups/new/upload"):
|
||||||
# We need to reuse the full content type that includes the boundary
|
# We need to reuse the full content type that includes the boundary
|
||||||
headers[
|
headers[
|
||||||
"Content-Type"
|
"Content-Type"
|
||||||
] = request._stored_content_type # pylint: disable=protected-access
|
] = request._stored_content_type # pylint: disable=protected-access
|
||||||
|
|
||||||
# Backups are big, so we need to adjust the allowed size
|
|
||||||
request._client_max_size = ( # pylint: disable=protected-access
|
|
||||||
MAX_UPLOAD_SIZE
|
|
||||||
)
|
|
||||||
client_timeout = 300
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with async_timeout.timeout(client_timeout):
|
# Stream the request to the supervisor
|
||||||
data = await request.read()
|
client = await self._websession.request(
|
||||||
|
method=request.method,
|
||||||
method = getattr(self._websession, request.method.lower())
|
url=f"http://{self._host}/{path}",
|
||||||
client = await method(
|
|
||||||
f"http://{self._host}/{path}",
|
|
||||||
data=data,
|
|
||||||
headers=headers,
|
headers=headers,
|
||||||
timeout=read_timeout,
|
data=request.content,
|
||||||
|
timeout=_get_timeout(path),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Simple request
|
# Stream the supervisor response back
|
||||||
if int(client.headers.get(CONTENT_LENGTH, 0)) < 4194000:
|
|
||||||
# Return Response
|
|
||||||
body = await client.read()
|
|
||||||
return web.Response(
|
|
||||||
content_type=client.content_type, status=client.status, body=body
|
|
||||||
)
|
|
||||||
|
|
||||||
# Stream response
|
|
||||||
response = web.StreamResponse(status=client.status, headers=client.headers)
|
response = web.StreamResponse(status=client.status, headers=client.headers)
|
||||||
response.content_type = client.content_type
|
response.content_type = client.content_type
|
||||||
|
|
||||||
@ -126,12 +103,9 @@ class HassIOView(HomeAssistantView):
|
|||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
except aiohttp.ClientError as err:
|
except ClientError as err:
|
||||||
_LOGGER.error("Client error on api %s request %s", path, err)
|
_LOGGER.error("Client error on api %s request %s", path, err)
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
_LOGGER.error("Client timeout error on API request %s", path)
|
|
||||||
|
|
||||||
raise HTTPBadGateway()
|
raise HTTPBadGateway()
|
||||||
|
|
||||||
|
|
||||||
@ -151,11 +125,11 @@ def _init_header(request: web.Request) -> dict[str, str]:
|
|||||||
return headers
|
return headers
|
||||||
|
|
||||||
|
|
||||||
def _get_timeout(path: str) -> int:
|
def _get_timeout(path: str) -> ClientTimeout:
|
||||||
"""Return timeout for a URL path."""
|
"""Return timeout for a URL path."""
|
||||||
if NO_TIMEOUT.match(path):
|
if NO_TIMEOUT.match(path):
|
||||||
return 0
|
return ClientTimeout(connect=10)
|
||||||
return 300
|
return ClientTimeout(connect=10, total=300)
|
||||||
|
|
||||||
|
|
||||||
def _need_auth(hass, path: str) -> bool:
|
def _need_auth(hass, path: str) -> bool:
|
||||||
|
@ -1,11 +1,13 @@
|
|||||||
"""The tests for the hassio component."""
|
"""The tests for the hassio component."""
|
||||||
import asyncio
|
from aiohttp.client import ClientError
|
||||||
from unittest.mock import patch
|
from aiohttp.streams import StreamReader
|
||||||
|
from aiohttp.test_utils import TestClient
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant.components.hassio.http import _need_auth
|
from homeassistant.components.hassio.http import _need_auth
|
||||||
|
|
||||||
|
from tests.test_util.aiohttp import AiohttpClientMocker
|
||||||
|
|
||||||
|
|
||||||
async def test_forward_request(hassio_client, aioclient_mock):
|
async def test_forward_request(hassio_client, aioclient_mock):
|
||||||
"""Test fetching normal path."""
|
"""Test fetching normal path."""
|
||||||
@ -106,16 +108,6 @@ async def test_forward_log_request(hassio_client, aioclient_mock):
|
|||||||
assert len(aioclient_mock.mock_calls) == 1
|
assert len(aioclient_mock.mock_calls) == 1
|
||||||
|
|
||||||
|
|
||||||
async def test_bad_gateway_when_cannot_find_supervisor(hassio_client):
|
|
||||||
"""Test we get a bad gateway error if we can't find supervisor."""
|
|
||||||
with patch(
|
|
||||||
"homeassistant.components.hassio.http.async_timeout.timeout",
|
|
||||||
side_effect=asyncio.TimeoutError,
|
|
||||||
):
|
|
||||||
resp = await hassio_client.get("/api/hassio/addons/test/info")
|
|
||||||
assert resp.status == 502
|
|
||||||
|
|
||||||
|
|
||||||
async def test_forwarding_user_info(hassio_client, hass_admin_user, aioclient_mock):
|
async def test_forwarding_user_info(hassio_client, hass_admin_user, aioclient_mock):
|
||||||
"""Test that we forward user info correctly."""
|
"""Test that we forward user info correctly."""
|
||||||
aioclient_mock.get("http://127.0.0.1/hello")
|
aioclient_mock.get("http://127.0.0.1/hello")
|
||||||
@ -171,6 +163,37 @@ async def test_backup_download_headers(hassio_client, aioclient_mock):
|
|||||||
assert resp.headers["Content-Disposition"] == content_disposition
|
assert resp.headers["Content-Disposition"] == content_disposition
|
||||||
|
|
||||||
|
|
||||||
|
async def test_supervisor_client_error(
|
||||||
|
hassio_client: TestClient, aioclient_mock: AiohttpClientMocker
|
||||||
|
):
|
||||||
|
"""Test any client error from the supervisor returns a 502."""
|
||||||
|
# Create a request that throws a ClientError
|
||||||
|
async def raise_client_error(*args):
|
||||||
|
raise ClientError()
|
||||||
|
|
||||||
|
aioclient_mock.get(
|
||||||
|
"http://127.0.0.1/test/raise/error",
|
||||||
|
side_effect=raise_client_error,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify it returns bad gateway
|
||||||
|
resp = await hassio_client.get("/api/hassio/test/raise/error")
|
||||||
|
assert resp.status == 502
|
||||||
|
assert len(aioclient_mock.mock_calls) == 1
|
||||||
|
|
||||||
|
|
||||||
|
async def test_streamed_requests(
|
||||||
|
hassio_client: TestClient, aioclient_mock: AiohttpClientMocker
|
||||||
|
):
|
||||||
|
"""Test requests get proxied to the supervisor as a stream."""
|
||||||
|
aioclient_mock.get("http://127.0.0.1/test/stream")
|
||||||
|
await hassio_client.get("/api/hassio/test/stream", data="Test data")
|
||||||
|
assert len(aioclient_mock.mock_calls) == 1
|
||||||
|
|
||||||
|
# Verify the request body is passed as a StreamReader
|
||||||
|
assert isinstance(aioclient_mock.mock_calls[0][2], StreamReader)
|
||||||
|
|
||||||
|
|
||||||
def test_need_auth(hass):
|
def test_need_auth(hass):
|
||||||
"""Test if the requested path needs authentication."""
|
"""Test if the requested path needs authentication."""
|
||||||
assert not _need_auth(hass, "addons/test/logo")
|
assert not _need_auth(hass, "addons/test/logo")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user