Incorporate SourceManager into HEOS Coordinator (#136377)

Co-authored-by: Joost Lekkerkerker <joostlek@outlook.com>
This commit is contained in:
Andrew Sayre 2025-01-24 04:56:41 -06:00 committed by GitHub
parent 50cf94ca9b
commit a3ba3bbb1d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 180 additions and 237 deletions

View File

@ -2,7 +2,6 @@
from __future__ import annotations from __future__ import annotations
import asyncio
from dataclasses import dataclass from dataclasses import dataclass
from datetime import timedelta from datetime import timedelta
import logging import logging
@ -13,7 +12,7 @@ from pyheos import Heos, HeosError, HeosPlayer, const as heos_const
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import Platform from homeassistant.const import Platform
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError, ServiceValidationError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import device_registry as dr from homeassistant.helpers import device_registry as dr
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.dispatcher import ( from homeassistant.helpers.dispatcher import (
@ -21,16 +20,9 @@ from homeassistant.helpers.dispatcher import (
async_dispatcher_send, async_dispatcher_send,
) )
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from homeassistant.util import Throttle
from . import services from . import services
from .const import ( from .const import DOMAIN, SIGNAL_HEOS_PLAYER_ADDED, SIGNAL_HEOS_UPDATED
COMMAND_RETRY_ATTEMPTS,
COMMAND_RETRY_DELAY,
DOMAIN,
SIGNAL_HEOS_PLAYER_ADDED,
SIGNAL_HEOS_UPDATED,
)
from .coordinator import HeosCoordinator from .coordinator import HeosCoordinator
PLATFORMS = [Platform.MEDIA_PLAYER] PLATFORMS = [Platform.MEDIA_PLAYER]
@ -48,7 +40,6 @@ class HeosRuntimeData:
coordinator: HeosCoordinator coordinator: HeosCoordinator
group_manager: GroupManager group_manager: GroupManager
source_manager: SourceManager
players: dict[int, HeosPlayer] players: dict[int, HeosPlayer]
@ -84,17 +75,10 @@ async def async_setup_entry(hass: HomeAssistant, entry: HeosConfigEntry) -> bool
# Preserve existing logic until migrated into coordinator # Preserve existing logic until migrated into coordinator
controller = coordinator.heos controller = coordinator.heos
players = controller.players players = controller.players
favorites = coordinator.favorites
inputs = coordinator.inputs
source_manager = SourceManager(favorites, inputs)
source_manager.connect_update(hass, controller)
group_manager = GroupManager(hass, controller, players) group_manager = GroupManager(hass, controller, players)
entry.runtime_data = HeosRuntimeData( entry.runtime_data = HeosRuntimeData(coordinator, group_manager, players)
coordinator, group_manager, source_manager, players
)
group_manager.connect_update() group_manager.connect_update()
entry.async_on_unload(group_manager.disconnect_update) entry.async_on_unload(group_manager.disconnect_update)
@ -234,135 +218,3 @@ class GroupManager:
def group_membership(self): def group_membership(self):
"""Provide access to group members for player entities.""" """Provide access to group members for player entities."""
return self._group_membership return self._group_membership
class SourceManager:
"""Class that manages sources for players."""
def __init__(
self,
favorites,
inputs,
*,
retry_delay: int = COMMAND_RETRY_DELAY,
max_retry_attempts: int = COMMAND_RETRY_ATTEMPTS,
) -> None:
"""Init input manager."""
self.retry_delay = retry_delay
self.max_retry_attempts = max_retry_attempts
self.favorites = favorites
self.inputs = inputs
self.source_list = self._build_source_list()
def _build_source_list(self):
"""Build a single list of inputs from various types."""
source_list = []
source_list.extend([favorite.name for favorite in self.favorites.values()])
source_list.extend([source.name for source in self.inputs])
return source_list
async def play_source(self, source: str, player):
"""Determine type of source and play it."""
index = next(
(
index
for index, favorite in self.favorites.items()
if favorite.name == source
),
None,
)
if index is not None:
await player.play_preset_station(index)
return
input_source = next(
(
input_source
for input_source in self.inputs
if input_source.name == source
),
None,
)
if input_source is not None:
await player.play_input_source(input_source.media_id)
return
raise ServiceValidationError(
translation_domain=DOMAIN,
translation_key="unknown_source",
translation_placeholders={"source": source},
)
def get_current_source(self, now_playing_media):
"""Determine current source from now playing media."""
# Match input by input_name:media_id
if now_playing_media.source_id == heos_const.MUSIC_SOURCE_AUX_INPUT:
return next(
(
input_source.name
for input_source in self.inputs
if input_source.media_id == now_playing_media.media_id
),
None,
)
# Try matching favorite by name:station or media_id:album_id
return next(
(
source.name
for source in self.favorites.values()
if source.name == now_playing_media.station
or source.media_id == now_playing_media.album_id
),
None,
)
@callback
def connect_update(self, hass: HomeAssistant, controller: Heos) -> None:
"""Connect listener for when sources change and signal player update.
EVENT_SOURCES_CHANGED is often raised multiple times in response to a
physical event therefore throttle it. Retrieving sources immediately
after the event may fail so retry.
"""
@Throttle(MIN_UPDATE_SOURCES)
async def get_sources():
retry_attempts = 0
while True:
try:
favorites = {}
if controller.is_signed_in:
favorites = await controller.get_favorites()
inputs = await controller.get_input_sources()
except HeosError as error:
if retry_attempts < self.max_retry_attempts:
retry_attempts += 1
_LOGGER.debug(
"Error retrieving sources and will retry: %s", error
)
await asyncio.sleep(self.retry_delay)
else:
_LOGGER.error("Unable to update sources: %s", error)
return None
else:
return favorites, inputs
async def _update_sources() -> None:
# If throttled, it will return None
if sources := await get_sources():
self.favorites, self.inputs = sources
self.source_list = self._build_source_list()
_LOGGER.debug("Sources updated due to changed event")
# Let players know to update
async_dispatcher_send(hass, SIGNAL_HEOS_UPDATED)
async def _on_controller_event(event: str, data: Any | None) -> None:
if event in (
heos_const.EVENT_SOURCES_CHANGED,
heos_const.EVENT_USER_CHANGED,
):
await _update_sources()
controller.add_on_connected(_update_sources)
controller.add_on_user_credentials_invalid(_update_sources)
controller.add_on_controller_event(_on_controller_event)

View File

@ -2,8 +2,6 @@
ATTR_PASSWORD = "password" ATTR_PASSWORD = "password"
ATTR_USERNAME = "username" ATTR_USERNAME = "username"
COMMAND_RETRY_ATTEMPTS = 2
COMMAND_RETRY_DELAY = 1
DOMAIN = "heos" DOMAIN = "heos"
SERVICE_SIGN_IN = "sign_in" SERVICE_SIGN_IN = "sign_in"
SERVICE_SIGN_OUT = "sign_out" SERVICE_SIGN_OUT = "sign_out"

View File

@ -5,23 +5,28 @@ The coordinator is responsible for refreshing data in response to system-wide ev
entities to update. Entities subscribe to entity-specific updates within the entity class itself. entities to update. Entities subscribe to entity-specific updates within the entity class itself.
""" """
from datetime import datetime, timedelta
import logging import logging
from pyheos import ( from pyheos import (
Credentials, Credentials,
Heos, Heos,
HeosError, HeosError,
HeosNowPlayingMedia,
HeosOptions, HeosOptions,
HeosPlayer,
MediaItem, MediaItem,
MediaType,
PlayerUpdateResult, PlayerUpdateResult,
const, const,
) )
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_USERNAME, Platform from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_USERNAME, Platform
from homeassistant.core import HomeAssistant from homeassistant.core import HassJob, HomeAssistant, callback
from homeassistant.exceptions import ConfigEntryNotReady from homeassistant.exceptions import ConfigEntryNotReady, ServiceValidationError
from homeassistant.helpers import device_registry as dr, entity_registry as er from homeassistant.helpers import device_registry as dr, entity_registry as er
from homeassistant.helpers.event import async_call_later
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator from homeassistant.helpers.update_coordinator import DataUpdateCoordinator
from . import DOMAIN from . import DOMAIN
@ -50,8 +55,10 @@ class HeosCoordinator(DataUpdateCoordinator[None]):
credentials=credentials, credentials=credentials,
) )
) )
self.favorites: dict[int, MediaItem] = {} self._update_sources_pending: bool = False
self.inputs: list[MediaItem] = [] self._source_list: list[str] = []
self._favorites: dict[int, MediaItem] = {}
self._inputs: list[MediaItem] = []
super().__init__(hass, _LOGGER, config_entry=config_entry, name=DOMAIN) super().__init__(hass, _LOGGER, config_entry=config_entry, name=DOMAIN)
async def async_setup(self) -> None: async def async_setup(self) -> None:
@ -99,6 +106,7 @@ class HeosCoordinator(DataUpdateCoordinator[None]):
async def _async_on_reconnected(self) -> None: async def _async_on_reconnected(self) -> None:
"""Handle when reconnected so resources are updated and entities marked available.""" """Handle when reconnected so resources are updated and entities marked available."""
await self._async_update_players() await self._async_update_players()
await self._async_update_sources()
_LOGGER.warning("Successfully reconnected to HEOS host %s", self.host) _LOGGER.warning("Successfully reconnected to HEOS host %s", self.host)
self.async_update_listeners() self.async_update_listeners()
@ -110,6 +118,31 @@ class HeosCoordinator(DataUpdateCoordinator[None]):
assert data is not None assert data is not None
if data.updated_player_ids: if data.updated_player_ids:
self._async_update_player_ids(data.updated_player_ids) self._async_update_player_ids(data.updated_player_ids)
elif (
event in (const.EVENT_SOURCES_CHANGED, const.EVENT_USER_CHANGED)
and not self._update_sources_pending
):
# Update the sources after a brief delay as we may have received multiple qualifying
# events at once and devices cannot handle immediately attempting to refresh sources.
self._update_sources_pending = True
async def update_sources_job(_: datetime | None = None) -> None:
await self._async_update_sources()
self._update_sources_pending = False
self.async_update_listeners()
assert self.config_entry is not None
self.config_entry.async_on_unload(
async_call_later(
self.hass,
timedelta(seconds=1),
HassJob(
update_sources_job,
"heos_update_sources",
cancel_on_shutdown=True,
),
)
)
self.async_update_listeners() self.async_update_listeners()
def _async_update_player_ids(self, updated_player_ids: dict[int, int]) -> None: def _async_update_player_ids(self, updated_player_ids: dict[int, int]) -> None:
@ -145,17 +178,24 @@ class HeosCoordinator(DataUpdateCoordinator[None]):
async def _async_update_sources(self) -> None: async def _async_update_sources(self) -> None:
"""Build source list for entities.""" """Build source list for entities."""
self._source_list.clear()
# Get favorites only if reportedly signed in. # Get favorites only if reportedly signed in.
if self.heos.is_signed_in: if self.heos.is_signed_in:
try: try:
self.favorites = await self.heos.get_favorites() self._favorites = await self.heos.get_favorites()
except HeosError as error: except HeosError as error:
_LOGGER.error("Unable to retrieve favorites: %s", error) _LOGGER.error("Unable to retrieve favorites: %s", error)
else:
self._source_list.extend(
favorite.name for favorite in self._favorites.values()
)
# Get input sources (across all devices in the HEOS system) # Get input sources (across all devices in the HEOS system)
try: try:
self.inputs = await self.heos.get_input_sources() self._inputs = await self.heos.get_input_sources()
except HeosError as error: except HeosError as error:
_LOGGER.error("Unable to retrieve input sources: %s", error) _LOGGER.error("Unable to retrieve input sources: %s", error)
else:
self._source_list.extend([source.name for source in self._inputs])
async def _async_update_players(self) -> None: async def _async_update_players(self) -> None:
"""Update players after reconnection.""" """Update players after reconnection."""
@ -167,3 +207,61 @@ class HeosCoordinator(DataUpdateCoordinator[None]):
# After reconnecting, player_id may have changed # After reconnecting, player_id may have changed
if player_updates.updated_player_ids: if player_updates.updated_player_ids:
self._async_update_player_ids(player_updates.updated_player_ids) self._async_update_player_ids(player_updates.updated_player_ids)
@callback
def async_get_source_list(self) -> list[str]:
"""Return the list of sources for players."""
return list(self._source_list)
@callback
def async_get_favorite_index(self, name: str) -> int | None:
"""Get the index of a favorite by name."""
for index, favorite in self._favorites.items():
if favorite.name == name:
return index
return None
@callback
def async_get_current_source(
self, now_playing_media: HeosNowPlayingMedia
) -> str | None:
"""Determine current source from now playing media (either input source or favorite)."""
# Try matching input source
if now_playing_media.source_id == const.MUSIC_SOURCE_AUX_INPUT:
# If playing a remote input, name will match station
for input_source in self._inputs:
if input_source.name == now_playing_media.station:
return input_source.name
# If playing a local input, match media_id. This needs to be a second loop as media_id
# will match both local and remote inputs, so prioritize remote match by name first.
for input_source in self._inputs:
if input_source.media_id == now_playing_media.media_id:
return input_source.name
# Try matching favorite
if now_playing_media.type == MediaType.STATION:
# Some stations match on name:station, others match on media_id:album_id
for favorite in self._favorites.values():
if (
favorite.name == now_playing_media.station
or favorite.media_id == now_playing_media.album_id
):
return favorite.name
return None
async def async_play_source(self, source: str, player: HeosPlayer) -> None:
"""Determine type of source and play it."""
# Favorite
if (index := self.async_get_favorite_index(source)) is not None:
await player.play_preset_station(index)
return
# Input source
for input_source in self._inputs:
if input_source.name == source:
await player.play_media(input_source)
return
raise ServiceValidationError(
translation_domain=DOMAIN,
translation_key="unknown_source",
translation_placeholders={"source": source},
)

View File

@ -40,7 +40,7 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.update_coordinator import CoordinatorEntity from homeassistant.helpers.update_coordinator import CoordinatorEntity
from homeassistant.util.dt import utcnow from homeassistant.util.dt import utcnow
from . import GroupManager, HeosConfigEntry, SourceManager from . import GroupManager, HeosConfigEntry
from .const import DOMAIN as HEOS_DOMAIN, SIGNAL_HEOS_PLAYER_ADDED, SIGNAL_HEOS_UPDATED from .const import DOMAIN as HEOS_DOMAIN, SIGNAL_HEOS_PLAYER_ADDED, SIGNAL_HEOS_UPDATED
from .coordinator import HeosCoordinator from .coordinator import HeosCoordinator
@ -97,7 +97,6 @@ async def async_setup_entry(
HeosMediaPlayer( HeosMediaPlayer(
entry.runtime_data.coordinator, entry.runtime_data.coordinator,
player, player,
entry.runtime_data.source_manager,
entry.runtime_data.group_manager, entry.runtime_data.group_manager,
) )
for player in players.values() for player in players.values()
@ -144,13 +143,11 @@ class HeosMediaPlayer(CoordinatorEntity[HeosCoordinator], MediaPlayerEntity):
self, self,
coordinator: HeosCoordinator, coordinator: HeosCoordinator,
player: HeosPlayer, player: HeosPlayer,
source_manager: SourceManager,
group_manager: GroupManager, group_manager: GroupManager,
) -> None: ) -> None:
"""Initialize.""" """Initialize."""
self._media_position_updated_at = None self._media_position_updated_at = None
self._player: HeosPlayer = player self._player: HeosPlayer = player
self._source_manager = source_manager
self._group_manager = group_manager self._group_manager = group_manager
self._attr_unique_id = str(player.player_id) self._attr_unique_id = str(player.player_id)
model_parts = player.model.split(maxsplit=1) model_parts = player.model.split(maxsplit=1)
@ -164,8 +161,8 @@ class HeosMediaPlayer(CoordinatorEntity[HeosCoordinator], MediaPlayerEntity):
serial_number=player.serial, # Only available for some models serial_number=player.serial, # Only available for some models
sw_version=player.version, sw_version=player.version,
) )
self._update_attributes()
super().__init__(coordinator, context=player.player_id) super().__init__(coordinator, context=player.player_id)
self._update_attributes()
async def _player_update(self, event): async def _player_update(self, event):
"""Handle player attribute updated.""" """Handle player attribute updated."""
@ -181,6 +178,10 @@ class HeosMediaPlayer(CoordinatorEntity[HeosCoordinator], MediaPlayerEntity):
def _update_attributes(self) -> None: def _update_attributes(self) -> None:
"""Update core attributes of the media player.""" """Update core attributes of the media player."""
self._attr_source_list = self.coordinator.async_get_source_list()
self._attr_source = self.coordinator.async_get_current_source(
self._player.now_playing_media
)
self._attr_repeat = HEOS_HA_REPEAT_TYPE_MAP[self._player.repeat] self._attr_repeat = HEOS_HA_REPEAT_TYPE_MAP[self._player.repeat]
controls = self._player.now_playing_media.supported_controls controls = self._player.now_playing_media.supported_controls
current_support = [CONTROL_TO_SUPPORT[control] for control in controls] current_support = [CONTROL_TO_SUPPORT[control] for control in controls]
@ -304,14 +305,7 @@ class HeosMediaPlayer(CoordinatorEntity[HeosCoordinator], MediaPlayerEntity):
index = int(media_id) index = int(media_id)
except ValueError: except ValueError:
# Try finding index by name # Try finding index by name
index = next( index = self.coordinator.async_get_favorite_index(media_id)
(
index
for index, favorite in self._source_manager.favorites.items()
if favorite.name == media_id
),
None,
)
if index is None: if index is None:
raise ValueError(f"Invalid favorite '{media_id}'") raise ValueError(f"Invalid favorite '{media_id}'")
await self._player.play_preset_station(index) await self._player.play_preset_station(index)
@ -322,7 +316,7 @@ class HeosMediaPlayer(CoordinatorEntity[HeosCoordinator], MediaPlayerEntity):
@catch_action_error("select source") @catch_action_error("select source")
async def async_select_source(self, source: str) -> None: async def async_select_source(self, source: str) -> None:
"""Select input source.""" """Select input source."""
await self._source_manager.play_source(source, self._player) await self.coordinator.async_play_source(source, self._player)
@catch_action_error("set repeat") @catch_action_error("set repeat")
async def async_set_repeat(self, repeat: RepeatMode) -> None: async def async_set_repeat(self, repeat: RepeatMode) -> None:
@ -428,16 +422,6 @@ class HeosMediaPlayer(CoordinatorEntity[HeosCoordinator], MediaPlayerEntity):
"""Boolean if shuffle is enabled.""" """Boolean if shuffle is enabled."""
return self._player.shuffle return self._player.shuffle
@property
def source(self) -> str:
"""Name of the current input source."""
return self._source_manager.get_current_source(self._player.now_playing_media)
@property
def source_list(self) -> list[str]:
"""List of available input sources."""
return self._source_manager.source_list
@property @property
def state(self) -> MediaPlayerState: def state(self) -> MediaPlayerState:
"""State of the player.""" """State of the player."""

View File

@ -139,7 +139,7 @@ def players_fixture(quick_selects: dict[int, str]) -> dict[int, HeosPlayer]:
player.mute = AsyncMock() player.mute = AsyncMock()
player.pause = AsyncMock() player.pause = AsyncMock()
player.play = AsyncMock() player.play = AsyncMock()
player.play_input_source = AsyncMock() player.play_media = AsyncMock()
player.play_next = AsyncMock() player.play_next = AsyncMock()
player.play_previous = AsyncMock() player.play_previous = AsyncMock()
player.play_preset_station = AsyncMock() player.play_preset_station = AsyncMock()
@ -193,17 +193,28 @@ def favorites_fixture() -> dict[int, MediaItem]:
@pytest.fixture(name="input_sources") @pytest.fixture(name="input_sources")
def input_sources_fixture() -> list[MediaItem]: def input_sources_fixture() -> list[MediaItem]:
"""Create a set of input sources for testing.""" """Create a set of input sources for testing."""
source = MediaItem( return [
source_id=1, MediaItem(
name="HEOS Drive - Line In 1", source_id=const.MUSIC_SOURCE_AUX_INPUT,
media_id=const.INPUT_AUX_IN_1, name="HEOS Drive - Line In 1",
type=MediaType.STATION, media_id=const.INPUT_AUX_IN_1,
playable=True, type=MediaType.STATION,
browsable=False, playable=True,
image_url="", browsable=False,
heos=None, image_url="",
) heos=None,
return [source] ),
MediaItem(
source_id=const.MUSIC_SOURCE_AUX_INPUT,
name="Speaker - Line In 1",
media_id=const.INPUT_AUX_IN_1,
type=MediaType.STATION,
playable=True,
browsable=False,
image_url="",
heos=None,
),
]
@pytest.fixture(name="discovery_data") @pytest.fixture(name="discovery_data")

View File

@ -25,6 +25,7 @@
"Today's Hits Radio", "Today's Hits Radio",
'Classical MPR (Classical Music)', 'Classical MPR (Classical Music)',
'HEOS Drive - Line In 1', 'HEOS Drive - Line In 1',
'Speaker - Line In 1',
]), ]),
'supported_features': <MediaPlayerEntityFeature: 3079741>, 'supported_features': <MediaPlayerEntityFeature: 3079741>,
'volume_level': 0.25, 'volume_level': 0.25,

View File

@ -2,15 +2,7 @@
from typing import cast from typing import cast
from pyheos import ( from pyheos import Heos, HeosError, HeosOptions, SignalHeosEvent, SignalType
CommandFailedError,
Heos,
HeosError,
HeosOptions,
SignalHeosEvent,
SignalType,
const,
)
import pytest import pytest
from homeassistant.components.heos.const import DOMAIN from homeassistant.components.heos.const import DOMAIN
@ -163,27 +155,6 @@ async def test_unload_entry(
assert controller.disconnect.call_count == 1 assert controller.disconnect.call_count == 1
async def test_update_sources_retry(
hass: HomeAssistant,
config_entry: MockConfigEntry,
controller: Heos,
) -> None:
"""Test update sources retries on failures to max attempts."""
config_entry.add_to_hass(hass)
assert await hass.config_entries.async_setup(config_entry.entry_id)
controller.get_favorites.reset_mock()
controller.get_input_sources.reset_mock()
source_manager = config_entry.runtime_data.source_manager
source_manager.retry_delay = 0
source_manager.max_retry_attempts = 1
controller.get_favorites.side_effect = CommandFailedError("Test", "test", 0)
await controller.dispatcher.wait_send(
SignalType.CONTROLLER_EVENT, const.EVENT_SOURCES_CHANGED, {}
)
await hass.async_block_till_done()
assert controller.get_favorites.call_count == 2
async def test_device_info( async def test_device_info(
hass: HomeAssistant, hass: HomeAssistant,
device_registry: dr.DeviceRegistry, device_registry: dr.DeviceRegistry,

View File

@ -1,14 +1,17 @@
"""Tests for the Heos Media Player platform.""" """Tests for the Heos Media Player platform."""
from datetime import timedelta
import re import re
from typing import Any from typing import Any
from freezegun.api import FrozenDateTimeFactory
from pyheos import ( from pyheos import (
AddCriteriaType, AddCriteriaType,
CommandFailedError, CommandFailedError,
Heos, Heos,
HeosError, HeosError,
MediaItem, MediaItem,
MediaType as HeosMediaType,
PlayerUpdateResult, PlayerUpdateResult,
PlayState, PlayState,
RepeatType, RepeatType,
@ -63,7 +66,7 @@ from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError, ServiceValidationError from homeassistant.exceptions import HomeAssistantError, ServiceValidationError
from homeassistant.helpers import device_registry as dr, entity_registry as er from homeassistant.helpers import device_registry as dr, entity_registry as er
from tests.common import MockConfigEntry from tests.common import MockConfigEntry, async_fire_time_changed
async def test_state_attributes( async def test_state_attributes(
@ -206,18 +209,21 @@ async def test_updates_from_sources_updated(
hass: HomeAssistant, hass: HomeAssistant,
config_entry: MockConfigEntry, config_entry: MockConfigEntry,
controller: Heos, controller: Heos,
input_sources: list[MediaItem], freezer: FrozenDateTimeFactory,
) -> None: ) -> None:
"""Tests player updates from changes in sources list.""" """Tests player updates from changes in sources list."""
config_entry.add_to_hass(hass) config_entry.add_to_hass(hass)
assert await hass.config_entries.async_setup(config_entry.entry_id) assert await hass.config_entries.async_setup(config_entry.entry_id)
player = controller.players[1] player = controller.players[1]
input_sources.clear() controller.get_input_sources.return_value = []
await player.heos.dispatcher.wait_send( await player.heos.dispatcher.wait_send(
SignalType.CONTROLLER_EVENT, const.EVENT_SOURCES_CHANGED, {} SignalType.CONTROLLER_EVENT, const.EVENT_SOURCES_CHANGED, {}
) )
freezer.tick(timedelta(seconds=1))
async_fire_time_changed(hass)
await hass.async_block_till_done() await hass.async_block_till_done()
state = hass.states.get("media_player.test_player") state = hass.states.get("media_player.test_player")
assert state.attributes[ATTR_INPUT_SOURCE_LIST] == [ assert state.attributes[ATTR_INPUT_SOURCE_LIST] == [
"Today's Hits Radio", "Today's Hits Radio",
@ -288,6 +294,7 @@ async def test_updates_from_user_changed(
hass: HomeAssistant, hass: HomeAssistant,
config_entry: MockConfigEntry, config_entry: MockConfigEntry,
controller: Heos, controller: Heos,
freezer: FrozenDateTimeFactory,
) -> None: ) -> None:
"""Tests player updates from changes in user.""" """Tests player updates from changes in user."""
config_entry.add_to_hass(hass) config_entry.add_to_hass(hass)
@ -298,10 +305,15 @@ async def test_updates_from_user_changed(
await player.heos.dispatcher.wait_send( await player.heos.dispatcher.wait_send(
SignalType.CONTROLLER_EVENT, const.EVENT_USER_CHANGED, None SignalType.CONTROLLER_EVENT, const.EVENT_USER_CHANGED, None
) )
freezer.tick(timedelta(seconds=1))
async_fire_time_changed(hass)
await hass.async_block_till_done() await hass.async_block_till_done()
state = hass.states.get("media_player.test_player") state = hass.states.get("media_player.test_player")
assert state.attributes[ATTR_INPUT_SOURCE_LIST] == ["HEOS Drive - Line In 1"] assert state.attributes[ATTR_INPUT_SOURCE_LIST] == [
"HEOS Drive - Line In 1",
"Speaker - Line In 1",
]
async def test_clear_playlist( async def test_clear_playlist(
@ -694,6 +706,7 @@ async def test_select_favorite(
) )
player.play_preset_station.assert_called_once_with(1) player.play_preset_station.assert_called_once_with(1)
# Test state is matched by station name # Test state is matched by station name
player.now_playing_media.type = HeosMediaType.STATION
player.now_playing_media.station = favorite.name player.now_playing_media.station = favorite.name
await player.heos.dispatcher.wait_send( await player.heos.dispatcher.wait_send(
SignalType.PLAYER_EVENT, player.player_id, const.EVENT_PLAYER_STATE_CHANGED SignalType.PLAYER_EVENT, player.player_id, const.EVENT_PLAYER_STATE_CHANGED
@ -723,6 +736,7 @@ async def test_select_radio_favorite(
) )
player.play_preset_station.assert_called_once_with(2) player.play_preset_station.assert_called_once_with(2)
# Test state is matched by album id # Test state is matched by album id
player.now_playing_media.type = HeosMediaType.STATION
player.now_playing_media.station = "Classical" player.now_playing_media.station = "Classical"
player.now_playing_media.album_id = favorite.media_id player.now_playing_media.album_id = favorite.media_id
await player.heos.dispatcher.wait_send( await player.heos.dispatcher.wait_send(
@ -762,37 +776,51 @@ async def test_select_radio_favorite_command_error(
player.play_preset_station.assert_called_once_with(2) player.play_preset_station.assert_called_once_with(2)
@pytest.mark.parametrize(
("source_name", "station"),
[
("HEOS Drive - Line In 1", "Line In 1"),
("Speaker - Line In 1", "Speaker - Line In 1"),
],
)
async def test_select_input_source( async def test_select_input_source(
hass: HomeAssistant, hass: HomeAssistant,
config_entry: MockConfigEntry, config_entry: MockConfigEntry,
controller: Heos, controller: Heos,
input_sources: list[MediaItem], input_sources: list[MediaItem],
source_name: str,
station: str,
) -> None: ) -> None:
"""Tests selecting input source and state.""" """Tests selecting input source and state."""
config_entry.add_to_hass(hass) config_entry.add_to_hass(hass)
assert await hass.config_entries.async_setup(config_entry.entry_id) assert await hass.config_entries.async_setup(config_entry.entry_id)
player = controller.players[1] player = controller.players[1]
# Test proper service called
input_source = input_sources[0]
await hass.services.async_call( await hass.services.async_call(
MEDIA_PLAYER_DOMAIN, MEDIA_PLAYER_DOMAIN,
SERVICE_SELECT_SOURCE, SERVICE_SELECT_SOURCE,
{ {
ATTR_ENTITY_ID: "media_player.test_player", ATTR_ENTITY_ID: "media_player.test_player",
ATTR_INPUT_SOURCE: input_source.name, ATTR_INPUT_SOURCE: source_name,
}, },
blocking=True, blocking=True,
) )
player.play_input_source.assert_called_once_with(input_source.media_id) input_sources = next(
# Test state is matched by media id input_sources
for input_sources in input_sources
if input_sources.name == source_name
)
player.play_media.assert_called_once_with(input_sources)
# Update the now_playing_media to reflect play_media
player.now_playing_media.source_id = const.MUSIC_SOURCE_AUX_INPUT player.now_playing_media.source_id = const.MUSIC_SOURCE_AUX_INPUT
player.now_playing_media.station = station
player.now_playing_media.media_id = const.INPUT_AUX_IN_1 player.now_playing_media.media_id = const.INPUT_AUX_IN_1
await player.heos.dispatcher.wait_send( await player.heos.dispatcher.wait_send(
SignalType.PLAYER_EVENT, player.player_id, const.EVENT_PLAYER_STATE_CHANGED SignalType.PLAYER_EVENT, player.player_id, const.EVENT_PLAYER_STATE_CHANGED
) )
await hass.async_block_till_done() await hass.async_block_till_done()
state = hass.states.get("media_player.test_player") state = hass.states.get("media_player.test_player")
assert state.attributes[ATTR_INPUT_SOURCE] == input_source.name assert state.attributes[ATTR_INPUT_SOURCE] == source_name
async def test_select_input_unknown_raises( async def test_select_input_unknown_raises(
@ -824,7 +852,7 @@ async def test_select_input_command_error(
await hass.config_entries.async_setup(config_entry.entry_id) await hass.config_entries.async_setup(config_entry.entry_id)
player = controller.players[1] player = controller.players[1]
input_source = input_sources[0] input_source = input_sources[0]
player.play_input_source.side_effect = CommandFailedError(None, "Failure", 1) player.play_media.side_effect = CommandFailedError(None, "Failure", 1)
with pytest.raises( with pytest.raises(
HomeAssistantError, HomeAssistantError,
match=re.escape("Unable to select source: Failure (1)"), match=re.escape("Unable to select source: Failure (1)"),
@ -838,7 +866,7 @@ async def test_select_input_command_error(
}, },
blocking=True, blocking=True,
) )
player.play_input_source.assert_called_once_with(input_source.media_id) player.play_media.assert_called_once_with(input_source)
async def test_unload_config_entry( async def test_unload_config_entry(