Refactor url fetch code to use base platform

This commit is contained in:
jbouwh 2023-06-26 22:14:55 +00:00 committed by Erik
parent 51edc007fe
commit 5f3bcee97e
2 changed files with 21 additions and 29 deletions

View File

@ -6,7 +6,6 @@ import binascii
from collections.abc import Callable from collections.abc import Callable
import functools import functools
import logging import logging
import ssl
from typing import Any from typing import Any
import httpx import httpx
@ -106,6 +105,7 @@ class MqttImage(MqttEntity, ImageEntity):
_entity_id_format: str = image.ENTITY_ID_FORMAT _entity_id_format: str = image.ENTITY_ID_FORMAT
_last_image: bytes | None = None _last_image: bytes | None = None
_client: httpx.AsyncClient _client: httpx.AsyncClient
_url: str | None = None
_url_template: Callable[[ReceivePayloadType], ReceivePayloadType] _url_template: Callable[[ReceivePayloadType], ReceivePayloadType]
_topic: dict[str, Any] _topic: dict[str, Any]
@ -143,23 +143,6 @@ class MqttImage(MqttEntity, ImageEntity):
config.get(CONF_URL_TEMPLATE), entity=self config.get(CONF_URL_TEMPLATE), entity=self
).async_render_with_possible_json_value ).async_render_with_possible_json_value
async def _async_load_image(self, url: str) -> None:
try:
response = await self._client.request(
"GET", url, timeout=GET_IMAGE_TIMEOUT, follow_redirects=True
)
except (httpx.TimeoutException, httpx.RequestError, ssl.SSLError) as ex:
_LOGGER.warning("Connection failed to url %s files: %s", url, ex)
self._last_image = None
self._attr_image_last_updated = dt_util.utcnow()
self.async_write_ha_state()
return
self._attr_content_type = response.headers["content-type"]
self._last_image = response.content
self._attr_image_last_updated = dt_util.utcnow()
self.async_write_ha_state()
def _prepare_subscribe_topics(self) -> None: def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
@ -211,14 +194,16 @@ class MqttImage(MqttEntity, ImageEntity):
try: try:
url = cv.url(self._url_template(msg.payload)) url = cv.url(self._url_template(msg.payload))
self._url = url
except vol.Invalid: except vol.Invalid:
_LOGGER.error( _LOGGER.error(
"Invalid image URL '%s' received at topic %s", "Invalid image URL '%s' received at topic %s",
msg.payload, msg.payload,
msg.topic, msg.topic,
) )
return self._last_image = None
self.hass.async_create_task(self._async_load_image(url)) self._attr_image_last_updated = dt_util.utcnow()
get_mqtt_data(self.hass).state_write_requests.write_state_request(self)
add_subscribe_topic(CONF_URL_TOPIC, image_from_url_request_received) add_subscribe_topic(CONF_URL_TOPIC, image_from_url_request_received)
@ -232,4 +217,10 @@ class MqttImage(MqttEntity, ImageEntity):
async def async_image(self) -> bytes | None: async def async_image(self) -> bytes | None:
"""Return bytes of image.""" """Return bytes of image."""
return self._last_image if CONF_IMAGE_TOPIC in self._config:
return self._last_image
return await super().async_image()
async def async_image_url(self) -> str | None:
"""Return URL of image."""
return self._url

View File

@ -233,7 +233,7 @@ async def test_image_from_url(
await hass.async_block_till_done() await hass.async_block_till_done()
state = hass.states.get("image.test") state = hass.states.get("image.test")
assert state.state == STATE_UNKNOWN assert state.state == "2023-04-01T00:00:00+00:00"
assert "Invalid image URL" in caplog.text assert "Invalid image URL" in caplog.text
@ -356,7 +356,7 @@ async def test_image_from_url_content_type(
await hass.async_block_till_done() await hass.async_block_till_done()
state = hass.states.get("image.test") state = hass.states.get("image.test")
assert state.state == STATE_UNKNOWN assert state.state == "2023-04-01T00:00:00+00:00"
access_token = state.attributes["access_token"] access_token = state.attributes["access_token"]
assert state.attributes == { assert state.attributes == {
@ -397,11 +397,11 @@ async def test_image_from_url_content_type(
], ],
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
("side_effect", "log_text"), "side_effect",
[ [
(httpx.RequestError("server offline", request=MagicMock()), "server offline"), httpx.RequestError("server offline", request=MagicMock()),
(httpx.TimeoutException, "Connection failed"), httpx.TimeoutException,
(ssl.SSLError, "Connection failed"), ssl.SSLError,
], ],
) )
async def test_image_from_url_fails( async def test_image_from_url_fails(
@ -410,7 +410,6 @@ async def test_image_from_url_fails(
mqtt_mock_entry: MqttMockHAClientGenerator, mqtt_mock_entry: MqttMockHAClientGenerator,
caplog: pytest.LogCaptureFixture, caplog: pytest.LogCaptureFixture,
side_effect: Exception, side_effect: Exception,
log_text: str,
) -> None: ) -> None:
"""Test setup with minimum configuration.""" """Test setup with minimum configuration."""
respx.get("http://localhost/test.png").mock(side_effect=side_effect) respx.get("http://localhost/test.png").mock(side_effect=side_effect)
@ -436,7 +435,9 @@ async def test_image_from_url_fails(
# The image failed to load, the the last image update is registered # The image failed to load, the the last image update is registered
# but _last_image was set to `None` # but _last_image was set to `None`
assert state.state == "2023-04-01T00:00:00+00:00" assert state.state == "2023-04-01T00:00:00+00:00"
assert log_text in caplog.text client = await hass_client_no_auth()
resp = await client.get(state.attributes["entity_picture"])
assert resp.status == HTTPStatus.INTERNAL_SERVER_ERROR
@respx.mock @respx.mock