Incorporate SourceManager into HEOS Coordinator (#136377)

Co-authored-by: Joost Lekkerkerker <joostlek@outlook.com>
This commit is contained in:
Andrew Sayre 2025-01-24 04:56:41 -06:00 committed by GitHub
parent 50cf94ca9b
commit a3ba3bbb1d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 180 additions and 237 deletions

View File

@ -2,7 +2,6 @@
from __future__ import annotations
import asyncio
from dataclasses import dataclass
from datetime import timedelta
import logging
@ -13,7 +12,7 @@ from pyheos import Heos, HeosError, HeosPlayer, const as heos_const
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import Platform
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError, ServiceValidationError
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import device_registry as dr
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.dispatcher import (
@ -21,16 +20,9 @@ from homeassistant.helpers.dispatcher import (
async_dispatcher_send,
)
from homeassistant.helpers.typing import ConfigType
from homeassistant.util import Throttle
from . import services
from .const import (
COMMAND_RETRY_ATTEMPTS,
COMMAND_RETRY_DELAY,
DOMAIN,
SIGNAL_HEOS_PLAYER_ADDED,
SIGNAL_HEOS_UPDATED,
)
from .const import DOMAIN, SIGNAL_HEOS_PLAYER_ADDED, SIGNAL_HEOS_UPDATED
from .coordinator import HeosCoordinator
PLATFORMS = [Platform.MEDIA_PLAYER]
@ -48,7 +40,6 @@ class HeosRuntimeData:
coordinator: HeosCoordinator
group_manager: GroupManager
source_manager: SourceManager
players: dict[int, HeosPlayer]
@ -84,17 +75,10 @@ async def async_setup_entry(hass: HomeAssistant, entry: HeosConfigEntry) -> bool
# Preserve existing logic until migrated into coordinator
controller = coordinator.heos
players = controller.players
favorites = coordinator.favorites
inputs = coordinator.inputs
source_manager = SourceManager(favorites, inputs)
source_manager.connect_update(hass, controller)
group_manager = GroupManager(hass, controller, players)
entry.runtime_data = HeosRuntimeData(
coordinator, group_manager, source_manager, players
)
entry.runtime_data = HeosRuntimeData(coordinator, group_manager, players)
group_manager.connect_update()
entry.async_on_unload(group_manager.disconnect_update)
@ -234,135 +218,3 @@ class GroupManager:
def group_membership(self):
"""Provide access to group members for player entities."""
return self._group_membership
class SourceManager:
"""Class that manages sources for players."""
def __init__(
self,
favorites,
inputs,
*,
retry_delay: int = COMMAND_RETRY_DELAY,
max_retry_attempts: int = COMMAND_RETRY_ATTEMPTS,
) -> None:
"""Init input manager."""
self.retry_delay = retry_delay
self.max_retry_attempts = max_retry_attempts
self.favorites = favorites
self.inputs = inputs
self.source_list = self._build_source_list()
def _build_source_list(self):
"""Build a single list of inputs from various types."""
source_list = []
source_list.extend([favorite.name for favorite in self.favorites.values()])
source_list.extend([source.name for source in self.inputs])
return source_list
async def play_source(self, source: str, player):
"""Determine type of source and play it."""
index = next(
(
index
for index, favorite in self.favorites.items()
if favorite.name == source
),
None,
)
if index is not None:
await player.play_preset_station(index)
return
input_source = next(
(
input_source
for input_source in self.inputs
if input_source.name == source
),
None,
)
if input_source is not None:
await player.play_input_source(input_source.media_id)
return
raise ServiceValidationError(
translation_domain=DOMAIN,
translation_key="unknown_source",
translation_placeholders={"source": source},
)
def get_current_source(self, now_playing_media):
"""Determine current source from now playing media."""
# Match input by input_name:media_id
if now_playing_media.source_id == heos_const.MUSIC_SOURCE_AUX_INPUT:
return next(
(
input_source.name
for input_source in self.inputs
if input_source.media_id == now_playing_media.media_id
),
None,
)
# Try matching favorite by name:station or media_id:album_id
return next(
(
source.name
for source in self.favorites.values()
if source.name == now_playing_media.station
or source.media_id == now_playing_media.album_id
),
None,
)
@callback
def connect_update(self, hass: HomeAssistant, controller: Heos) -> None:
"""Connect listener for when sources change and signal player update.
EVENT_SOURCES_CHANGED is often raised multiple times in response to a
physical event therefore throttle it. Retrieving sources immediately
after the event may fail so retry.
"""
@Throttle(MIN_UPDATE_SOURCES)
async def get_sources():
retry_attempts = 0
while True:
try:
favorites = {}
if controller.is_signed_in:
favorites = await controller.get_favorites()
inputs = await controller.get_input_sources()
except HeosError as error:
if retry_attempts < self.max_retry_attempts:
retry_attempts += 1
_LOGGER.debug(
"Error retrieving sources and will retry: %s", error
)
await asyncio.sleep(self.retry_delay)
else:
_LOGGER.error("Unable to update sources: %s", error)
return None
else:
return favorites, inputs
async def _update_sources() -> None:
# If throttled, it will return None
if sources := await get_sources():
self.favorites, self.inputs = sources
self.source_list = self._build_source_list()
_LOGGER.debug("Sources updated due to changed event")
# Let players know to update
async_dispatcher_send(hass, SIGNAL_HEOS_UPDATED)
async def _on_controller_event(event: str, data: Any | None) -> None:
if event in (
heos_const.EVENT_SOURCES_CHANGED,
heos_const.EVENT_USER_CHANGED,
):
await _update_sources()
controller.add_on_connected(_update_sources)
controller.add_on_user_credentials_invalid(_update_sources)
controller.add_on_controller_event(_on_controller_event)

View File

@ -2,8 +2,6 @@
ATTR_PASSWORD = "password"
ATTR_USERNAME = "username"
COMMAND_RETRY_ATTEMPTS = 2
COMMAND_RETRY_DELAY = 1
DOMAIN = "heos"
SERVICE_SIGN_IN = "sign_in"
SERVICE_SIGN_OUT = "sign_out"

View File

@ -5,23 +5,28 @@ The coordinator is responsible for refreshing data in response to system-wide ev
entities to update. Entities subscribe to entity-specific updates within the entity class itself.
"""
from datetime import datetime, timedelta
import logging
from pyheos import (
Credentials,
Heos,
HeosError,
HeosNowPlayingMedia,
HeosOptions,
HeosPlayer,
MediaItem,
MediaType,
PlayerUpdateResult,
const,
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_USERNAME, Platform
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryNotReady
from homeassistant.core import HassJob, HomeAssistant, callback
from homeassistant.exceptions import ConfigEntryNotReady, ServiceValidationError
from homeassistant.helpers import device_registry as dr, entity_registry as er
from homeassistant.helpers.event import async_call_later
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator
from . import DOMAIN
@ -50,8 +55,10 @@ class HeosCoordinator(DataUpdateCoordinator[None]):
credentials=credentials,
)
)
self.favorites: dict[int, MediaItem] = {}
self.inputs: list[MediaItem] = []
self._update_sources_pending: bool = False
self._source_list: list[str] = []
self._favorites: dict[int, MediaItem] = {}
self._inputs: list[MediaItem] = []
super().__init__(hass, _LOGGER, config_entry=config_entry, name=DOMAIN)
async def async_setup(self) -> None:
@ -99,6 +106,7 @@ class HeosCoordinator(DataUpdateCoordinator[None]):
async def _async_on_reconnected(self) -> None:
"""Handle when reconnected so resources are updated and entities marked available."""
await self._async_update_players()
await self._async_update_sources()
_LOGGER.warning("Successfully reconnected to HEOS host %s", self.host)
self.async_update_listeners()
@ -110,6 +118,31 @@ class HeosCoordinator(DataUpdateCoordinator[None]):
assert data is not None
if data.updated_player_ids:
self._async_update_player_ids(data.updated_player_ids)
elif (
event in (const.EVENT_SOURCES_CHANGED, const.EVENT_USER_CHANGED)
and not self._update_sources_pending
):
# Update the sources after a brief delay as we may have received multiple qualifying
# events at once and devices cannot handle immediately attempting to refresh sources.
self._update_sources_pending = True
async def update_sources_job(_: datetime | None = None) -> None:
await self._async_update_sources()
self._update_sources_pending = False
self.async_update_listeners()
assert self.config_entry is not None
self.config_entry.async_on_unload(
async_call_later(
self.hass,
timedelta(seconds=1),
HassJob(
update_sources_job,
"heos_update_sources",
cancel_on_shutdown=True,
),
)
)
self.async_update_listeners()
def _async_update_player_ids(self, updated_player_ids: dict[int, int]) -> None:
@ -145,17 +178,24 @@ class HeosCoordinator(DataUpdateCoordinator[None]):
async def _async_update_sources(self) -> None:
"""Build source list for entities."""
self._source_list.clear()
# Get favorites only if reportedly signed in.
if self.heos.is_signed_in:
try:
self.favorites = await self.heos.get_favorites()
self._favorites = await self.heos.get_favorites()
except HeosError as error:
_LOGGER.error("Unable to retrieve favorites: %s", error)
else:
self._source_list.extend(
favorite.name for favorite in self._favorites.values()
)
# Get input sources (across all devices in the HEOS system)
try:
self.inputs = await self.heos.get_input_sources()
self._inputs = await self.heos.get_input_sources()
except HeosError as error:
_LOGGER.error("Unable to retrieve input sources: %s", error)
else:
self._source_list.extend([source.name for source in self._inputs])
async def _async_update_players(self) -> None:
"""Update players after reconnection."""
@ -167,3 +207,61 @@ class HeosCoordinator(DataUpdateCoordinator[None]):
# After reconnecting, player_id may have changed
if player_updates.updated_player_ids:
self._async_update_player_ids(player_updates.updated_player_ids)
@callback
def async_get_source_list(self) -> list[str]:
"""Return the list of sources for players."""
return list(self._source_list)
@callback
def async_get_favorite_index(self, name: str) -> int | None:
"""Get the index of a favorite by name."""
for index, favorite in self._favorites.items():
if favorite.name == name:
return index
return None
@callback
def async_get_current_source(
self, now_playing_media: HeosNowPlayingMedia
) -> str | None:
"""Determine current source from now playing media (either input source or favorite)."""
# Try matching input source
if now_playing_media.source_id == const.MUSIC_SOURCE_AUX_INPUT:
# If playing a remote input, name will match station
for input_source in self._inputs:
if input_source.name == now_playing_media.station:
return input_source.name
# If playing a local input, match media_id. This needs to be a second loop as media_id
# will match both local and remote inputs, so prioritize remote match by name first.
for input_source in self._inputs:
if input_source.media_id == now_playing_media.media_id:
return input_source.name
# Try matching favorite
if now_playing_media.type == MediaType.STATION:
# Some stations match on name:station, others match on media_id:album_id
for favorite in self._favorites.values():
if (
favorite.name == now_playing_media.station
or favorite.media_id == now_playing_media.album_id
):
return favorite.name
return None
async def async_play_source(self, source: str, player: HeosPlayer) -> None:
"""Determine type of source and play it."""
# Favorite
if (index := self.async_get_favorite_index(source)) is not None:
await player.play_preset_station(index)
return
# Input source
for input_source in self._inputs:
if input_source.name == source:
await player.play_media(input_source)
return
raise ServiceValidationError(
translation_domain=DOMAIN,
translation_key="unknown_source",
translation_placeholders={"source": source},
)

View File

@ -40,7 +40,7 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.update_coordinator import CoordinatorEntity
from homeassistant.util.dt import utcnow
from . import GroupManager, HeosConfigEntry, SourceManager
from . import GroupManager, HeosConfigEntry
from .const import DOMAIN as HEOS_DOMAIN, SIGNAL_HEOS_PLAYER_ADDED, SIGNAL_HEOS_UPDATED
from .coordinator import HeosCoordinator
@ -97,7 +97,6 @@ async def async_setup_entry(
HeosMediaPlayer(
entry.runtime_data.coordinator,
player,
entry.runtime_data.source_manager,
entry.runtime_data.group_manager,
)
for player in players.values()
@ -144,13 +143,11 @@ class HeosMediaPlayer(CoordinatorEntity[HeosCoordinator], MediaPlayerEntity):
self,
coordinator: HeosCoordinator,
player: HeosPlayer,
source_manager: SourceManager,
group_manager: GroupManager,
) -> None:
"""Initialize."""
self._media_position_updated_at = None
self._player: HeosPlayer = player
self._source_manager = source_manager
self._group_manager = group_manager
self._attr_unique_id = str(player.player_id)
model_parts = player.model.split(maxsplit=1)
@ -164,8 +161,8 @@ class HeosMediaPlayer(CoordinatorEntity[HeosCoordinator], MediaPlayerEntity):
serial_number=player.serial, # Only available for some models
sw_version=player.version,
)
self._update_attributes()
super().__init__(coordinator, context=player.player_id)
self._update_attributes()
async def _player_update(self, event):
"""Handle player attribute updated."""
@ -181,6 +178,10 @@ class HeosMediaPlayer(CoordinatorEntity[HeosCoordinator], MediaPlayerEntity):
def _update_attributes(self) -> None:
"""Update core attributes of the media player."""
self._attr_source_list = self.coordinator.async_get_source_list()
self._attr_source = self.coordinator.async_get_current_source(
self._player.now_playing_media
)
self._attr_repeat = HEOS_HA_REPEAT_TYPE_MAP[self._player.repeat]
controls = self._player.now_playing_media.supported_controls
current_support = [CONTROL_TO_SUPPORT[control] for control in controls]
@ -304,14 +305,7 @@ class HeosMediaPlayer(CoordinatorEntity[HeosCoordinator], MediaPlayerEntity):
index = int(media_id)
except ValueError:
# Try finding index by name
index = next(
(
index
for index, favorite in self._source_manager.favorites.items()
if favorite.name == media_id
),
None,
)
index = self.coordinator.async_get_favorite_index(media_id)
if index is None:
raise ValueError(f"Invalid favorite '{media_id}'")
await self._player.play_preset_station(index)
@ -322,7 +316,7 @@ class HeosMediaPlayer(CoordinatorEntity[HeosCoordinator], MediaPlayerEntity):
@catch_action_error("select source")
async def async_select_source(self, source: str) -> None:
"""Select input source."""
await self._source_manager.play_source(source, self._player)
await self.coordinator.async_play_source(source, self._player)
@catch_action_error("set repeat")
async def async_set_repeat(self, repeat: RepeatMode) -> None:
@ -428,16 +422,6 @@ class HeosMediaPlayer(CoordinatorEntity[HeosCoordinator], MediaPlayerEntity):
"""Boolean if shuffle is enabled."""
return self._player.shuffle
@property
def source(self) -> str:
"""Name of the current input source."""
return self._source_manager.get_current_source(self._player.now_playing_media)
@property
def source_list(self) -> list[str]:
"""List of available input sources."""
return self._source_manager.source_list
@property
def state(self) -> MediaPlayerState:
"""State of the player."""

View File

@ -139,7 +139,7 @@ def players_fixture(quick_selects: dict[int, str]) -> dict[int, HeosPlayer]:
player.mute = AsyncMock()
player.pause = AsyncMock()
player.play = AsyncMock()
player.play_input_source = AsyncMock()
player.play_media = AsyncMock()
player.play_next = AsyncMock()
player.play_previous = AsyncMock()
player.play_preset_station = AsyncMock()
@ -193,8 +193,9 @@ def favorites_fixture() -> dict[int, MediaItem]:
@pytest.fixture(name="input_sources")
def input_sources_fixture() -> list[MediaItem]:
"""Create a set of input sources for testing."""
source = MediaItem(
source_id=1,
return [
MediaItem(
source_id=const.MUSIC_SOURCE_AUX_INPUT,
name="HEOS Drive - Line In 1",
media_id=const.INPUT_AUX_IN_1,
type=MediaType.STATION,
@ -202,8 +203,18 @@ def input_sources_fixture() -> list[MediaItem]:
browsable=False,
image_url="",
heos=None,
)
return [source]
),
MediaItem(
source_id=const.MUSIC_SOURCE_AUX_INPUT,
name="Speaker - Line In 1",
media_id=const.INPUT_AUX_IN_1,
type=MediaType.STATION,
playable=True,
browsable=False,
image_url="",
heos=None,
),
]
@pytest.fixture(name="discovery_data")

View File

@ -25,6 +25,7 @@
"Today's Hits Radio",
'Classical MPR (Classical Music)',
'HEOS Drive - Line In 1',
'Speaker - Line In 1',
]),
'supported_features': <MediaPlayerEntityFeature: 3079741>,
'volume_level': 0.25,

View File

@ -2,15 +2,7 @@
from typing import cast
from pyheos import (
CommandFailedError,
Heos,
HeosError,
HeosOptions,
SignalHeosEvent,
SignalType,
const,
)
from pyheos import Heos, HeosError, HeosOptions, SignalHeosEvent, SignalType
import pytest
from homeassistant.components.heos.const import DOMAIN
@ -163,27 +155,6 @@ async def test_unload_entry(
assert controller.disconnect.call_count == 1
async def test_update_sources_retry(
hass: HomeAssistant,
config_entry: MockConfigEntry,
controller: Heos,
) -> None:
"""Test update sources retries on failures to max attempts."""
config_entry.add_to_hass(hass)
assert await hass.config_entries.async_setup(config_entry.entry_id)
controller.get_favorites.reset_mock()
controller.get_input_sources.reset_mock()
source_manager = config_entry.runtime_data.source_manager
source_manager.retry_delay = 0
source_manager.max_retry_attempts = 1
controller.get_favorites.side_effect = CommandFailedError("Test", "test", 0)
await controller.dispatcher.wait_send(
SignalType.CONTROLLER_EVENT, const.EVENT_SOURCES_CHANGED, {}
)
await hass.async_block_till_done()
assert controller.get_favorites.call_count == 2
async def test_device_info(
hass: HomeAssistant,
device_registry: dr.DeviceRegistry,

View File

@ -1,14 +1,17 @@
"""Tests for the Heos Media Player platform."""
from datetime import timedelta
import re
from typing import Any
from freezegun.api import FrozenDateTimeFactory
from pyheos import (
AddCriteriaType,
CommandFailedError,
Heos,
HeosError,
MediaItem,
MediaType as HeosMediaType,
PlayerUpdateResult,
PlayState,
RepeatType,
@ -63,7 +66,7 @@ from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError, ServiceValidationError
from homeassistant.helpers import device_registry as dr, entity_registry as er
from tests.common import MockConfigEntry
from tests.common import MockConfigEntry, async_fire_time_changed
async def test_state_attributes(
@ -206,18 +209,21 @@ async def test_updates_from_sources_updated(
hass: HomeAssistant,
config_entry: MockConfigEntry,
controller: Heos,
input_sources: list[MediaItem],
freezer: FrozenDateTimeFactory,
) -> None:
"""Tests player updates from changes in sources list."""
config_entry.add_to_hass(hass)
assert await hass.config_entries.async_setup(config_entry.entry_id)
player = controller.players[1]
input_sources.clear()
controller.get_input_sources.return_value = []
await player.heos.dispatcher.wait_send(
SignalType.CONTROLLER_EVENT, const.EVENT_SOURCES_CHANGED, {}
)
freezer.tick(timedelta(seconds=1))
async_fire_time_changed(hass)
await hass.async_block_till_done()
state = hass.states.get("media_player.test_player")
assert state.attributes[ATTR_INPUT_SOURCE_LIST] == [
"Today's Hits Radio",
@ -288,6 +294,7 @@ async def test_updates_from_user_changed(
hass: HomeAssistant,
config_entry: MockConfigEntry,
controller: Heos,
freezer: FrozenDateTimeFactory,
) -> None:
"""Tests player updates from changes in user."""
config_entry.add_to_hass(hass)
@ -298,10 +305,15 @@ async def test_updates_from_user_changed(
await player.heos.dispatcher.wait_send(
SignalType.CONTROLLER_EVENT, const.EVENT_USER_CHANGED, None
)
freezer.tick(timedelta(seconds=1))
async_fire_time_changed(hass)
await hass.async_block_till_done()
state = hass.states.get("media_player.test_player")
assert state.attributes[ATTR_INPUT_SOURCE_LIST] == ["HEOS Drive - Line In 1"]
assert state.attributes[ATTR_INPUT_SOURCE_LIST] == [
"HEOS Drive - Line In 1",
"Speaker - Line In 1",
]
async def test_clear_playlist(
@ -694,6 +706,7 @@ async def test_select_favorite(
)
player.play_preset_station.assert_called_once_with(1)
# Test state is matched by station name
player.now_playing_media.type = HeosMediaType.STATION
player.now_playing_media.station = favorite.name
await player.heos.dispatcher.wait_send(
SignalType.PLAYER_EVENT, player.player_id, const.EVENT_PLAYER_STATE_CHANGED
@ -723,6 +736,7 @@ async def test_select_radio_favorite(
)
player.play_preset_station.assert_called_once_with(2)
# Test state is matched by album id
player.now_playing_media.type = HeosMediaType.STATION
player.now_playing_media.station = "Classical"
player.now_playing_media.album_id = favorite.media_id
await player.heos.dispatcher.wait_send(
@ -762,37 +776,51 @@ async def test_select_radio_favorite_command_error(
player.play_preset_station.assert_called_once_with(2)
@pytest.mark.parametrize(
("source_name", "station"),
[
("HEOS Drive - Line In 1", "Line In 1"),
("Speaker - Line In 1", "Speaker - Line In 1"),
],
)
async def test_select_input_source(
hass: HomeAssistant,
config_entry: MockConfigEntry,
controller: Heos,
input_sources: list[MediaItem],
source_name: str,
station: str,
) -> None:
"""Tests selecting input source and state."""
config_entry.add_to_hass(hass)
assert await hass.config_entries.async_setup(config_entry.entry_id)
player = controller.players[1]
# Test proper service called
input_source = input_sources[0]
await hass.services.async_call(
MEDIA_PLAYER_DOMAIN,
SERVICE_SELECT_SOURCE,
{
ATTR_ENTITY_ID: "media_player.test_player",
ATTR_INPUT_SOURCE: input_source.name,
ATTR_INPUT_SOURCE: source_name,
},
blocking=True,
)
player.play_input_source.assert_called_once_with(input_source.media_id)
# Test state is matched by media id
input_sources = next(
input_sources
for input_sources in input_sources
if input_sources.name == source_name
)
player.play_media.assert_called_once_with(input_sources)
# Update the now_playing_media to reflect play_media
player.now_playing_media.source_id = const.MUSIC_SOURCE_AUX_INPUT
player.now_playing_media.station = station
player.now_playing_media.media_id = const.INPUT_AUX_IN_1
await player.heos.dispatcher.wait_send(
SignalType.PLAYER_EVENT, player.player_id, const.EVENT_PLAYER_STATE_CHANGED
)
await hass.async_block_till_done()
state = hass.states.get("media_player.test_player")
assert state.attributes[ATTR_INPUT_SOURCE] == input_source.name
assert state.attributes[ATTR_INPUT_SOURCE] == source_name
async def test_select_input_unknown_raises(
@ -824,7 +852,7 @@ async def test_select_input_command_error(
await hass.config_entries.async_setup(config_entry.entry_id)
player = controller.players[1]
input_source = input_sources[0]
player.play_input_source.side_effect = CommandFailedError(None, "Failure", 1)
player.play_media.side_effect = CommandFailedError(None, "Failure", 1)
with pytest.raises(
HomeAssistantError,
match=re.escape("Unable to select source: Failure (1)"),
@ -838,7 +866,7 @@ async def test_select_input_command_error(
},
blocking=True,
)
player.play_input_source.assert_called_once_with(input_source.media_id)
player.play_media.assert_called_once_with(input_source)
async def test_unload_config_entry(