From d6fcdeac06dcec668c6ae521bd7676c6d76c6bf3 Mon Sep 17 00:00:00 2001 From: Joakim Plate Date: Mon, 14 Aug 2023 18:03:17 +0200 Subject: [PATCH] Avoid leaking backtrace on connection lost in arcam (#98277) * Avoid leaking backtrace on connection lost * Correct ruff error after rebase --- .../components/arcam_fmj/media_player.py | 45 +++++++++++-- tests/components/arcam_fmj/conftest.py | 3 + .../components/arcam_fmj/test_media_player.py | 63 +++++++++++++++++-- 3 files changed, 100 insertions(+), 11 deletions(-) diff --git a/homeassistant/components/arcam_fmj/media_player.py b/homeassistant/components/arcam_fmj/media_player.py index 0173005eb2f..12114ec04b8 100644 --- a/homeassistant/components/arcam_fmj/media_player.py +++ b/homeassistant/components/arcam_fmj/media_player.py @@ -1,10 +1,11 @@ """Arcam media player.""" from __future__ import annotations +import functools import logging from typing import Any -from arcam.fmj import SourceCodes +from arcam.fmj import ConnectionFailed, SourceCodes from arcam.fmj.state import State from homeassistant.components.media_player import ( @@ -19,6 +20,7 @@ from homeassistant.components.media_player.errors import BrowseError from homeassistant.config_entries import ConfigEntry from homeassistant.const import ATTR_ENTITY_ID from homeassistant.core import HomeAssistant, callback +from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers.device_registry import DeviceInfo from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.entity_platform import AddEntitiesCallback @@ -57,6 +59,21 @@ async def async_setup_entry( ) +def convert_exception(func): + """Return decorator to convert a connection error into a home assistant error.""" + + @functools.wraps(func) + async def _convert_exception(*args, **kwargs): + try: + return await func(*args, **kwargs) + except ConnectionFailed as exception: + raise HomeAssistantError( + f"Connection failed to device during {func}" + ) from exception + + return _convert_exception + + class ArcamFmj(MediaPlayerEntity): """Representation of a media device.""" @@ -105,7 +122,10 @@ class ArcamFmj(MediaPlayerEntity): async def async_added_to_hass(self) -> None: """Once registered, add listener for events.""" await self._state.start() - await self._state.update() + try: + await self._state.update() + except ConnectionFailed as connection: + _LOGGER.debug("Connection lost during addition: %s", connection) @callback def _data(host: str) -> None: @@ -137,13 +157,18 @@ class ArcamFmj(MediaPlayerEntity): async def async_update(self) -> None: """Force update of state.""" _LOGGER.debug("Update state %s", self.name) - await self._state.update() + try: + await self._state.update() + except ConnectionFailed as connection: + _LOGGER.debug("Connection lost during update: %s", connection) + @convert_exception async def async_mute_volume(self, mute: bool) -> None: """Send mute command.""" await self._state.set_mute(mute) self.async_write_ha_state() + @convert_exception async def async_select_source(self, source: str) -> None: """Select a specific source.""" try: @@ -155,31 +180,37 @@ class ArcamFmj(MediaPlayerEntity): await self._state.set_source(value) self.async_write_ha_state() + @convert_exception async def async_select_sound_mode(self, sound_mode: str) -> None: """Select a specific source.""" try: await self._state.set_decode_mode(sound_mode) - except (KeyError, ValueError): - _LOGGER.error("Unsupported sound_mode %s", sound_mode) - return + except (KeyError, ValueError) as exception: + raise HomeAssistantError( + f"Unsupported sound_mode {sound_mode}" + ) from exception self.async_write_ha_state() + @convert_exception async def async_set_volume_level(self, volume: float) -> None: """Set volume level, range 0..1.""" await self._state.set_volume(round(volume * 99.0)) self.async_write_ha_state() + @convert_exception async def async_volume_up(self) -> None: """Turn volume up for media player.""" await self._state.inc_volume() self.async_write_ha_state() + @convert_exception async def async_volume_down(self) -> None: """Turn volume up for media player.""" await self._state.dec_volume() self.async_write_ha_state() + @convert_exception async def async_turn_on(self) -> None: """Turn the media player on.""" if self._state.get_power() is not None: @@ -189,6 +220,7 @@ class ArcamFmj(MediaPlayerEntity): _LOGGER.debug("Firing event to turn on device") self.hass.bus.async_fire(EVENT_TURN_ON, {ATTR_ENTITY_ID: self.entity_id}) + @convert_exception async def async_turn_off(self) -> None: """Turn the media player off.""" await self._state.set_power(False) @@ -230,6 +262,7 @@ class ArcamFmj(MediaPlayerEntity): return root + @convert_exception async def async_play_media( self, media_type: MediaType | str, media_id: str, **kwargs: Any ) -> None: diff --git a/tests/components/arcam_fmj/conftest.py b/tests/components/arcam_fmj/conftest.py index 693cdc685c9..ba32951efe4 100644 --- a/tests/components/arcam_fmj/conftest.py +++ b/tests/components/arcam_fmj/conftest.py @@ -8,6 +8,7 @@ import pytest from homeassistant.components.arcam_fmj.const import DEFAULT_NAME from homeassistant.components.arcam_fmj.media_player import ArcamFmj from homeassistant.const import CONF_HOST, CONF_PORT +from homeassistant.setup import async_setup_component from tests.common import MockConfigEntry, MockEntityPlatform @@ -94,6 +95,8 @@ async def player_setup_fixture(hass, state_1, state_2, client): if zone == 2: return state_2 + await async_setup_component(hass, "homeassistant", {}) + with patch("homeassistant.components.arcam_fmj.Client", return_value=client), patch( "homeassistant.components.arcam_fmj.media_player.State", side_effect=state_mock ), patch("homeassistant.components.arcam_fmj._run_client", return_value=None): diff --git a/tests/components/arcam_fmj/test_media_player.py b/tests/components/arcam_fmj/test_media_player.py index 2607ab817df..b9c86140cb9 100644 --- a/tests/components/arcam_fmj/test_media_player.py +++ b/tests/components/arcam_fmj/test_media_player.py @@ -2,14 +2,20 @@ from math import isclose from unittest.mock import ANY, PropertyMock, patch -from arcam.fmj import DecodeMode2CH, DecodeModeMCH, SourceCodes +from arcam.fmj import ConnectionFailed, DecodeMode2CH, DecodeModeMCH, SourceCodes import pytest +from homeassistant.components.homeassistant import ( + DOMAIN as HA_DOMAIN, + SERVICE_UPDATE_ENTITY, +) from homeassistant.components.media_player import ( ATTR_INPUT_SOURCE, + ATTR_MEDIA_VOLUME_LEVEL, ATTR_SOUND_MODE, ATTR_SOUND_MODE_LIST, SERVICE_SELECT_SOURCE, + SERVICE_VOLUME_SET, MediaType, ) from homeassistant.const import ( @@ -20,6 +26,7 @@ from homeassistant.const import ( ATTR_NAME, ) from homeassistant.core import HomeAssistant +from homeassistant.exceptions import HomeAssistantError from .conftest import MOCK_HOST, MOCK_UUID @@ -106,12 +113,33 @@ async def test_name(player) -> None: assert data.attributes["friendly_name"] == "Zone 1" -async def test_update(player, state) -> None: +async def test_update(hass: HomeAssistant, player_setup: str, state) -> None: """Test update.""" - await update(player, force_refresh=True) + await hass.services.async_call( + HA_DOMAIN, + SERVICE_UPDATE_ENTITY, + service_data={ATTR_ENTITY_ID: player_setup}, + blocking=True, + ) state.update.assert_called_with() +async def test_update_lost( + hass: HomeAssistant, player_setup: str, state, caplog: pytest.LogCaptureFixture +) -> None: + """Test update, with connection loss is ignored.""" + state.update.side_effect = ConnectionFailed() + + await hass.services.async_call( + HA_DOMAIN, + SERVICE_UPDATE_ENTITY, + service_data={ATTR_ENTITY_ID: player_setup}, + blocking=True, + ) + state.update.assert_called_with() + assert "Connection lost during update" in caplog.text + + @pytest.mark.parametrize( ("source", "value"), [("PVR", SourceCodes.PVR), ("BD", SourceCodes.BD), ("INVALID", None)], @@ -220,12 +248,37 @@ async def test_volume_level(player, state) -> None: @pytest.mark.parametrize(("volume", "call"), [(0.0, 0), (0.5, 50), (1.0, 99)]) -async def test_set_volume_level(player, state, volume, call) -> None: +async def test_set_volume_level( + hass: HomeAssistant, player_setup: str, state, volume, call +) -> None: """Test setting volume.""" - await player.async_set_volume_level(volume) + + await hass.services.async_call( + "media_player", + SERVICE_VOLUME_SET, + service_data={ATTR_ENTITY_ID: player_setup, ATTR_MEDIA_VOLUME_LEVEL: volume}, + blocking=True, + ) + state.set_volume.assert_called_with(call) +async def test_set_volume_level_lost( + hass: HomeAssistant, player_setup: str, state +) -> None: + """Test setting volume, with a lost connection.""" + + state.set_volume.side_effect = ConnectionFailed() + + with pytest.raises(HomeAssistantError): + await hass.services.async_call( + "media_player", + SERVICE_VOLUME_SET, + service_data={ATTR_ENTITY_ID: player_setup, ATTR_MEDIA_VOLUME_LEVEL: 0.0}, + blocking=True, + ) + + @pytest.mark.parametrize( ("source", "media_content_type"), [