Stream API requests to the supervisor (#53909)

This commit is contained in:
Joakim Sørensen 2021-08-03 16:48:22 +02:00 committed by GitHub
parent 2105419a4e
commit 56360feb9a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 49 additions and 39 deletions

View File

@@ -8,9 +8,14 @@ import re
import aiohttp import aiohttp
from aiohttp import web from aiohttp import web
from aiohttp.hdrs import CONTENT_LENGTH, CONTENT_TYPE from aiohttp.client import ClientTimeout
from aiohttp.hdrs import (
CONTENT_ENCODING,
CONTENT_LENGTH,
CONTENT_TYPE,
TRANSFER_ENCODING,
)
from aiohttp.web_exceptions import HTTPBadGateway from aiohttp.web_exceptions import HTTPBadGateway
import async_timeout
from homeassistant.components.http import KEY_AUTHENTICATED, HomeAssistantView from homeassistant.components.http import KEY_AUTHENTICATED, HomeAssistantView
from homeassistant.components.onboarding import async_is_onboarded from homeassistant.components.onboarding import async_is_onboarded
@@ -75,14 +80,11 @@ class HassIOView(HomeAssistantView):
async def _command_proxy( async def _command_proxy(
self, path: str, request: web.Request self, path: str, request: web.Request
) -> web.Response | web.StreamResponse: ) -> web.StreamResponse:
"""Return a client request with proxy origin for Hass.io supervisor. """Return a client request with proxy origin for Hass.io supervisor.
This method is a coroutine. This method is a coroutine.
""" """
read_timeout = _get_timeout(path)
client_timeout = 10
data = None
headers = _init_header(request) headers = _init_header(request)
if path in ("snapshots/new/upload", "backups/new/upload"): if path in ("snapshots/new/upload", "backups/new/upload"):
# We need to reuse the full content type that includes the boundary # We need to reuse the full content type that includes the boundary
@@ -90,34 +92,20 @@ class HassIOView(HomeAssistantView):
"Content-Type" "Content-Type"
] = request._stored_content_type # pylint: disable=protected-access ] = request._stored_content_type # pylint: disable=protected-access
# Backups are big, so we need to adjust the allowed size
request._client_max_size = ( # pylint: disable=protected-access
MAX_UPLOAD_SIZE
)
client_timeout = 300
try: try:
with async_timeout.timeout(client_timeout): client = await self._websession.request(
data = await request.read() method=request.method,
url=f"http://{self._host}/{path}",
method = getattr(self._websession, request.method.lower()) params=request.query,
client = await method( data=request.content,
f"http://{self._host}/{path}",
data=data,
headers=headers, headers=headers,
timeout=read_timeout, timeout=_get_timeout(path),
) )
# Simple request
if int(client.headers.get(CONTENT_LENGTH, 0)) < 4194000:
# Return Response
body = await client.read()
return web.Response(
content_type=client.content_type, status=client.status, body=body
)
# Stream response # Stream response
response = web.StreamResponse(status=client.status, headers=client.headers) response = web.StreamResponse(
status=client.status, headers=_response_header(client)
)
response.content_type = client.content_type response.content_type = client.content_type
await response.prepare(request) await response.prepare(request)
@@ -151,11 +139,28 @@ def _init_header(request: web.Request) -> dict[str, str]:
return headers return headers
def _get_timeout(path: str) -> int: def _response_header(response: aiohttp.ClientResponse) -> dict[str, str]:
"""Create response header."""
headers = {}
for name, value in response.headers.items():
if name in (
TRANSFER_ENCODING,
CONTENT_LENGTH,
CONTENT_TYPE,
CONTENT_ENCODING,
):
continue
headers[name] = value
return headers
def _get_timeout(path: str) -> ClientTimeout:
"""Return timeout for a URL path.""" """Return timeout for a URL path."""
if NO_TIMEOUT.match(path): if NO_TIMEOUT.match(path):
return 0 return ClientTimeout(connect=10, total=None)
return 300 return ClientTimeout(connect=10, total=300)
def _need_auth(hass, path: str) -> bool: def _need_auth(hass, path: str) -> bool:

View File

@@ -1,7 +1,7 @@
"""The tests for the hassio component.""" """The tests for the hassio component."""
import asyncio import asyncio
from unittest.mock import patch
from aiohttp import StreamReader
import pytest import pytest
from homeassistant.components.hassio.http import _need_auth from homeassistant.components.hassio.http import _need_auth
@@ -106,13 +106,11 @@ async def test_forward_log_request(hassio_client, aioclient_mock):
assert len(aioclient_mock.mock_calls) == 1 assert len(aioclient_mock.mock_calls) == 1
async def test_bad_gateway_when_cannot_find_supervisor(hassio_client): async def test_bad_gateway_when_cannot_find_supervisor(hassio_client, aioclient_mock):
"""Test we get a bad gateway error if we can't find supervisor.""" """Test we get a bad gateway error if we can't find supervisor."""
with patch( aioclient_mock.get("http://127.0.0.1/addons/test/info", exc=asyncio.TimeoutError)
"homeassistant.components.hassio.http.async_timeout.timeout",
side_effect=asyncio.TimeoutError, resp = await hassio_client.get("/api/hassio/addons/test/info")
):
resp = await hassio_client.get("/api/hassio/addons/test/info")
assert resp.status == 502 assert resp.status == 502
@@ -180,3 +178,10 @@ def test_need_auth(hass):
hass.data["onboarding"] = False hass.data["onboarding"] = False
assert not _need_auth(hass, "backups/new/upload") assert not _need_auth(hass, "backups/new/upload")
assert not _need_auth(hass, "supervisor/logs") assert not _need_auth(hass, "supervisor/logs")
async def test_stream(hassio_client, aioclient_mock):
"""Verify that the request is a stream."""
aioclient_mock.get("http://127.0.0.1/test")
await hassio_client.get("/api/hassio/test", data="test")
assert isinstance(aioclient_mock.mock_calls[-1][2], StreamReader)