Shield async httpx call in generic (#47852)

* Shield async httpx call

* Don't set last_url/last_image on cancellation

* Add test
This commit is contained in:
uvjustin 2021-03-31 12:46:10 +08:00 committed by GitHub
parent 7a6c88feeb
commit 379843eb54
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 74 additions and 9 deletions

View File

@ -125,32 +125,45 @@ class GenericCamera(Camera):
).result() ).result()
async def async_camera_image(self): async def async_camera_image(self):
"""Wrap _async_camera_image with an asyncio.shield."""
# Shield the request because of https://github.com/encode/httpx/issues/1461
try:
self._last_url, self._last_image = await asyncio.shield(
self._async_camera_image()
)
except asyncio.CancelledError as err:
_LOGGER.warning("Timeout getting camera image from %s", self._name)
raise err
return self._last_image
async def _async_camera_image(self):
"""Return a still image response from the camera.""" """Return a still image response from the camera."""
try: try:
url = self._still_image_url.async_render(parse_result=False) url = self._still_image_url.async_render(parse_result=False)
except TemplateError as err: except TemplateError as err:
_LOGGER.error("Error parsing template %s: %s", self._still_image_url, err) _LOGGER.error("Error parsing template %s: %s", self._still_image_url, err)
return self._last_image return self._last_url, self._last_image
if url == self._last_url and self._limit_refetch: if url == self._last_url and self._limit_refetch:
return self._last_image return self._last_url, self._last_image
response = None
try: try:
async_client = get_async_client(self.hass, verify_ssl=self.verify_ssl) async_client = get_async_client(self.hass, verify_ssl=self.verify_ssl)
response = await async_client.get( response = await async_client.get(
url, auth=self._auth, timeout=GET_IMAGE_TIMEOUT url, auth=self._auth, timeout=GET_IMAGE_TIMEOUT
) )
response.raise_for_status() response.raise_for_status()
self._last_image = response.content image = response.content
except httpx.TimeoutException: except httpx.TimeoutException:
_LOGGER.error("Timeout getting camera image from %s", self._name) _LOGGER.error("Timeout getting camera image from %s", self._name)
return self._last_image return self._last_url, self._last_image
except (httpx.RequestError, httpx.HTTPStatusError) as err: except (httpx.RequestError, httpx.HTTPStatusError) as err:
_LOGGER.error("Error getting new camera image from %s: %s", self._name, err) _LOGGER.error("Error getting new camera image from %s: %s", self._name, err)
return self._last_image return self._last_url, self._last_image
finally:
self._last_url = url if response:
return self._last_image await response.aclose()
return url, image
@property @property
def name(self): def name(self):

View File

@ -3,6 +3,7 @@ import asyncio
from os import path from os import path
from unittest.mock import patch from unittest.mock import patch
import httpx
import respx import respx
from homeassistant import config as hass_config from homeassistant import config as hass_config
@ -407,5 +408,56 @@ async def test_reloading(hass, hass_client):
assert body == "hello world" assert body == "hello world"
@respx.mock
async def test_timeout_cancelled(hass, hass_client):
"""Test that timeouts and cancellations return last image."""
respx.get("http://example.com").respond(text="hello world")
await async_setup_component(
hass,
"camera",
{
"camera": {
"name": "config_test",
"platform": "generic",
"still_image_url": "http://example.com",
"username": "user",
"password": "pass",
}
},
)
await hass.async_block_till_done()
client = await hass_client()
resp = await client.get("/api/camera_proxy/camera.config_test")
assert resp.status == 200
assert respx.calls.call_count == 1
assert await resp.text() == "hello world"
respx.get("http://example.com").respond(text="not hello world")
with patch(
"homeassistant.components.generic.camera.GenericCamera._async_camera_image",
side_effect=asyncio.CancelledError(),
):
resp = await client.get("/api/camera_proxy/camera.config_test")
assert respx.calls.call_count == 1
assert resp.status == 500
respx.get("http://example.com").side_effect = [
httpx.RequestError,
httpx.TimeoutException,
]
for total_calls in range(2, 4):
resp = await client.get("/api/camera_proxy/camera.config_test")
assert respx.calls.call_count == total_calls
assert resp.status == 200
assert await resp.text() == "hello world"
def _get_fixtures_base_path(): def _get_fixtures_base_path():
return path.dirname(path.dirname(path.dirname(__file__))) return path.dirname(path.dirname(path.dirname(__file__)))