Use runtime data in HEOS (#132030)

* Adopt runtime_data

* Fix missing variable assignment

* Address PR feedback
This commit is contained in:
Andrew Sayre 2024-12-02 01:19:43 -06:00 committed by GitHub
parent 4eb5734d73
commit 4eb75a56e6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 62 additions and 88 deletions

View File

@ -3,10 +3,11 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from dataclasses import dataclass
from datetime import timedelta from datetime import timedelta
import logging import logging
from pyheos import Heos, HeosError, const as heos_const from pyheos import Heos, HeosError, HeosPlayer, const as heos_const
import voluptuous as vol import voluptuous as vol
from homeassistant.config_entries import SOURCE_IMPORT, ConfigEntry from homeassistant.config_entries import SOURCE_IMPORT, ConfigEntry
@ -27,10 +28,6 @@ from .config_flow import format_title
from .const import ( from .const import (
COMMAND_RETRY_ATTEMPTS, COMMAND_RETRY_ATTEMPTS,
COMMAND_RETRY_DELAY, COMMAND_RETRY_DELAY,
DATA_CONTROLLER_MANAGER,
DATA_ENTITY_ID_MAP,
DATA_GROUP_MANAGER,
DATA_SOURCE_MANAGER,
DOMAIN, DOMAIN,
SIGNAL_HEOS_PLAYER_ADDED, SIGNAL_HEOS_PLAYER_ADDED,
SIGNAL_HEOS_UPDATED, SIGNAL_HEOS_UPDATED,
@ -51,6 +48,19 @@ MIN_UPDATE_SOURCES = timedelta(seconds=1)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@dataclass
class HeosRuntimeData:
"""Runtime data and coordinators for HEOS config entries."""
controller_manager: ControllerManager
group_manager: GroupManager
source_manager: SourceManager
players: dict[int, HeosPlayer]
type HeosConfigEntry = ConfigEntry[HeosRuntimeData]
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up the HEOS component.""" """Set up the HEOS component."""
if DOMAIN not in config: if DOMAIN not in config:
@ -75,7 +85,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True return True
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_setup_entry(hass: HomeAssistant, entry: HeosConfigEntry) -> bool:
"""Initialize config entry which represents the HEOS controller.""" """Initialize config entry which represents the HEOS controller."""
# For backwards compat # For backwards compat
if entry.unique_id is None: if entry.unique_id is None:
@ -128,17 +138,11 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
source_manager = SourceManager(favorites, inputs) source_manager = SourceManager(favorites, inputs)
source_manager.connect_update(hass, controller) source_manager.connect_update(hass, controller)
group_manager = GroupManager(hass, controller) group_manager = GroupManager(hass, controller, players)
hass.data[DOMAIN] = { entry.runtime_data = HeosRuntimeData(
DATA_CONTROLLER_MANAGER: controller_manager, controller_manager, group_manager, source_manager, players
DATA_GROUP_MANAGER: group_manager, )
DATA_SOURCE_MANAGER: source_manager,
Platform.MEDIA_PLAYER: players,
# Maps player_id to entity_id. Populated by the individual
# HeosMediaPlayer entities.
DATA_ENTITY_ID_MAP: {},
}
services.register(hass, controller) services.register(hass, controller)
group_manager.connect_update() group_manager.connect_update()
@ -149,11 +153,9 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
return True return True
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_unload_entry(hass: HomeAssistant, entry: HeosConfigEntry) -> bool:
"""Unload a config entry.""" """Unload a config entry."""
controller_manager = hass.data[DOMAIN][DATA_CONTROLLER_MANAGER] await entry.runtime_data.controller_manager.disconnect()
await controller_manager.disconnect()
hass.data.pop(DOMAIN)
services.remove(hass) services.remove(hass)
@ -246,21 +248,25 @@ class ControllerManager:
class GroupManager: class GroupManager:
"""Class that manages HEOS groups.""" """Class that manages HEOS groups."""
def __init__(self, hass, controller): def __init__(
self, hass: HomeAssistant, controller: Heos, players: dict[int, HeosPlayer]
) -> None:
"""Init group manager.""" """Init group manager."""
self._hass = hass self._hass = hass
self._group_membership = {} self._group_membership: dict[str, str] = {}
self._disconnect_player_added = None self._disconnect_player_added = None
self._initialized = False self._initialized = False
self.controller = controller self.controller = controller
self.players = players
self.entity_id_map: dict[int, str] = {}
def _get_entity_id_to_player_id_map(self) -> dict: def _get_entity_id_to_player_id_map(self) -> dict:
"""Return mapping of all HeosMediaPlayer entity_ids to player_ids.""" """Return mapping of all HeosMediaPlayer entity_ids to player_ids."""
return {v: k for k, v in self._hass.data[DOMAIN][DATA_ENTITY_ID_MAP].items()} return {v: k for k, v in self.entity_id_map.items()}
async def async_get_group_membership(self): async def async_get_group_membership(self) -> dict[str, list[str]]:
"""Return all group members for each player as entity_ids.""" """Return all group members for each player as entity_ids."""
group_info_by_entity_id = { group_info_by_entity_id: dict[str, list[str]] = {
player_entity_id: [] player_entity_id: []
for player_entity_id in self._get_entity_id_to_player_id_map() for player_entity_id in self._get_entity_id_to_player_id_map()
} }
@ -271,7 +277,7 @@ class GroupManager:
_LOGGER.error("Unable to get HEOS group info: %s", err) _LOGGER.error("Unable to get HEOS group info: %s", err)
return group_info_by_entity_id return group_info_by_entity_id
player_id_to_entity_id_map = self._hass.data[DOMAIN][DATA_ENTITY_ID_MAP] player_id_to_entity_id_map = self.entity_id_map
for group in groups.values(): for group in groups.values():
leader_entity_id = player_id_to_entity_id_map.get(group.leader.player_id) leader_entity_id = player_id_to_entity_id_map.get(group.leader.player_id)
member_entity_ids = [ member_entity_ids = [
@ -282,9 +288,9 @@ class GroupManager:
# Make sure the group leader is always the first element # Make sure the group leader is always the first element
group_info = [leader_entity_id, *member_entity_ids] group_info = [leader_entity_id, *member_entity_ids]
if leader_entity_id: if leader_entity_id:
group_info_by_entity_id[leader_entity_id] = group_info group_info_by_entity_id[leader_entity_id] = group_info # type: ignore[assignment]
for member_entity_id in member_entity_ids: for member_entity_id in member_entity_ids:
group_info_by_entity_id[member_entity_id] = group_info group_info_by_entity_id[member_entity_id] = group_info # type: ignore[assignment]
return group_info_by_entity_id return group_info_by_entity_id
@ -358,13 +364,9 @@ class GroupManager:
# When adding a new HEOS player we need to update the groups. # When adding a new HEOS player we need to update the groups.
async def _async_handle_player_added(): async def _async_handle_player_added():
# Avoid calling async_update_groups when `DATA_ENTITY_ID_MAP` has not been # Avoid calling async_update_groups when the entity_id map has not been
# fully populated yet. This may only happen during early startup. # fully populated yet. This may only happen during early startup.
if ( if len(self.players) <= len(self.entity_id_map) and not self._initialized:
len(self._hass.data[DOMAIN][Platform.MEDIA_PLAYER])
<= len(self._hass.data[DOMAIN][DATA_ENTITY_ID_MAP])
and not self._initialized
):
self._initialized = True self._initialized = True
await self.async_update_groups(SIGNAL_HEOS_PLAYER_ADDED) await self.async_update_groups(SIGNAL_HEOS_PLAYER_ADDED)

View File

@ -4,10 +4,6 @@ ATTR_PASSWORD = "password"
ATTR_USERNAME = "username" ATTR_USERNAME = "username"
COMMAND_RETRY_ATTEMPTS = 2 COMMAND_RETRY_ATTEMPTS = 2
COMMAND_RETRY_DELAY = 1 COMMAND_RETRY_DELAY = 1
DATA_CONTROLLER_MANAGER = "controller"
DATA_ENTITY_ID_MAP = "entity_id_map"
DATA_GROUP_MANAGER = "group_manager"
DATA_SOURCE_MANAGER = "source_manager"
DATA_DISCOVERED_HOSTS = "heos_discovered_hosts" DATA_DISCOVERED_HOSTS = "heos_discovered_hosts"
DOMAIN = "heos" DOMAIN = "heos"
SERVICE_SIGN_IN = "sign_in" SERVICE_SIGN_IN = "sign_in"

View File

@ -13,7 +13,6 @@ from pyheos import HeosError, const as heos_const
from homeassistant.components import media_source from homeassistant.components import media_source
from homeassistant.components.media_player import ( from homeassistant.components.media_player import (
ATTR_MEDIA_ENQUEUE, ATTR_MEDIA_ENQUEUE,
DOMAIN as MEDIA_PLAYER_DOMAIN,
BrowseMedia, BrowseMedia,
MediaPlayerEnqueue, MediaPlayerEnqueue,
MediaPlayerEntity, MediaPlayerEntity,
@ -22,7 +21,6 @@ from homeassistant.components.media_player import (
MediaType, MediaType,
async_process_play_media_url, async_process_play_media_url,
) )
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers.device_registry import DeviceInfo from homeassistant.helpers.device_registry import DeviceInfo
from homeassistant.helpers.dispatcher import ( from homeassistant.helpers.dispatcher import (
@ -32,14 +30,8 @@ from homeassistant.helpers.dispatcher import (
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.util.dt import utcnow from homeassistant.util.dt import utcnow
from .const import ( from . import GroupManager, HeosConfigEntry, SourceManager
DATA_ENTITY_ID_MAP, from .const import DOMAIN as HEOS_DOMAIN, SIGNAL_HEOS_PLAYER_ADDED, SIGNAL_HEOS_UPDATED
DATA_GROUP_MANAGER,
DATA_SOURCE_MANAGER,
DOMAIN as HEOS_DOMAIN,
SIGNAL_HEOS_PLAYER_ADDED,
SIGNAL_HEOS_UPDATED,
)
BASE_SUPPORTED_FEATURES = ( BASE_SUPPORTED_FEATURES = (
MediaPlayerEntityFeature.VOLUME_MUTE MediaPlayerEntityFeature.VOLUME_MUTE
@ -80,11 +72,16 @@ _LOGGER = logging.getLogger(__name__)
async def async_setup_entry( async def async_setup_entry(
hass: HomeAssistant, entry: ConfigEntry, async_add_entities: AddEntitiesCallback hass: HomeAssistant, entry: HeosConfigEntry, async_add_entities: AddEntitiesCallback
) -> None: ) -> None:
"""Add media players for a config entry.""" """Add media players for a config entry."""
players = hass.data[HEOS_DOMAIN][MEDIA_PLAYER_DOMAIN] players = entry.runtime_data.players
devices = [HeosMediaPlayer(player) for player in players.values()] devices = [
HeosMediaPlayer(
player, entry.runtime_data.source_manager, entry.runtime_data.group_manager
)
for player in players.values()
]
async_add_entities(devices, True) async_add_entities(devices, True)
@ -120,13 +117,15 @@ class HeosMediaPlayer(MediaPlayerEntity):
_attr_has_entity_name = True _attr_has_entity_name = True
_attr_name = None _attr_name = None
def __init__(self, player): def __init__(
self, player, source_manager: SourceManager, group_manager: GroupManager
) -> None:
"""Initialize.""" """Initialize."""
self._media_position_updated_at = None self._media_position_updated_at = None
self._player = player self._player = player
self._signals = [] self._signals: list = []
self._source_manager = None self._source_manager = source_manager
self._group_manager = None self._group_manager = group_manager
self._attr_unique_id = str(player.player_id) self._attr_unique_id = str(player.player_id)
self._attr_device_info = DeviceInfo( self._attr_device_info = DeviceInfo(
identifiers={(HEOS_DOMAIN, player.player_id)}, identifiers={(HEOS_DOMAIN, player.player_id)},
@ -161,9 +160,7 @@ class HeosMediaPlayer(MediaPlayerEntity):
async_dispatcher_connect(self.hass, SIGNAL_HEOS_UPDATED, self._heos_updated) async_dispatcher_connect(self.hass, SIGNAL_HEOS_UPDATED, self._heos_updated)
) )
# Register this player's entity_id so it can be resolved by the group manager # Register this player's entity_id so it can be resolved by the group manager
self.hass.data[HEOS_DOMAIN][DATA_ENTITY_ID_MAP][self._player.player_id] = ( self._group_manager.entity_id_map[self._player.player_id] = self.entity_id
self.entity_id
)
async_dispatcher_send(self.hass, SIGNAL_HEOS_PLAYER_ADDED) async_dispatcher_send(self.hass, SIGNAL_HEOS_PLAYER_ADDED)
@log_command_error("clear playlist") @log_command_error("clear playlist")
@ -294,12 +291,6 @@ class HeosMediaPlayer(MediaPlayerEntity):
ior, current_support, BASE_SUPPORTED_FEATURES ior, current_support, BASE_SUPPORTED_FEATURES
) )
if self._group_manager is None:
self._group_manager = self.hass.data[HEOS_DOMAIN][DATA_GROUP_MANAGER]
if self._source_manager is None:
self._source_manager = self.hass.data[HEOS_DOMAIN][DATA_SOURCE_MANAGER]
@log_command_error("unjoin_player") @log_command_error("unjoin_player")
async def async_unjoin_player(self) -> None: async def async_unjoin_player(self) -> None:
"""Remove this player from any group.""" """Remove this player from any group."""

View File

@ -8,15 +8,11 @@ import pytest
from homeassistant.components.heos import ( from homeassistant.components.heos import (
ControllerManager, ControllerManager,
HeosRuntimeData,
async_setup_entry, async_setup_entry,
async_unload_entry, async_unload_entry,
) )
from homeassistant.components.heos.const import ( from homeassistant.components.heos.const import DOMAIN
DATA_CONTROLLER_MANAGER,
DATA_SOURCE_MANAGER,
DOMAIN,
)
from homeassistant.components.media_player import DOMAIN as MEDIA_PLAYER_DOMAIN
from homeassistant.const import CONF_HOST from homeassistant.const import CONF_HOST
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryNotReady from homeassistant.exceptions import ConfigEntryNotReady
@ -92,10 +88,6 @@ async def test_async_setup_entry_loads_platforms(
assert controller.get_favorites.call_count == 1 assert controller.get_favorites.call_count == 1
assert controller.get_input_sources.call_count == 1 assert controller.get_input_sources.call_count == 1
controller.disconnect.assert_not_called() controller.disconnect.assert_not_called()
assert hass.data[DOMAIN][DATA_CONTROLLER_MANAGER].controller == controller
assert hass.data[DOMAIN][MEDIA_PLAYER_DOMAIN] == controller.players
assert hass.data[DOMAIN][DATA_SOURCE_MANAGER].favorites == favorites
assert hass.data[DOMAIN][DATA_SOURCE_MANAGER].inputs == input_sources
async def test_async_setup_entry_not_signed_in_loads_platforms( async def test_async_setup_entry_not_signed_in_loads_platforms(
@ -121,10 +113,6 @@ async def test_async_setup_entry_not_signed_in_loads_platforms(
assert controller.get_favorites.call_count == 0 assert controller.get_favorites.call_count == 0
assert controller.get_input_sources.call_count == 1 assert controller.get_input_sources.call_count == 1
controller.disconnect.assert_not_called() controller.disconnect.assert_not_called()
assert hass.data[DOMAIN][DATA_CONTROLLER_MANAGER].controller == controller
assert hass.data[DOMAIN][MEDIA_PLAYER_DOMAIN] == controller.players
assert hass.data[DOMAIN][DATA_SOURCE_MANAGER].favorites == {}
assert hass.data[DOMAIN][DATA_SOURCE_MANAGER].inputs == input_sources
assert ( assert (
"127.0.0.1 is not logged in to a HEOS account and will be unable to retrieve " "127.0.0.1 is not logged in to a HEOS account and will be unable to retrieve "
"HEOS favorites: Use the 'heos.sign_in' service to sign-in to a HEOS account" "HEOS favorites: Use the 'heos.sign_in' service to sign-in to a HEOS account"
@ -163,7 +151,8 @@ async def test_async_setup_entry_player_failure(
async def test_unload_entry(hass: HomeAssistant, config_entry, controller) -> None: async def test_unload_entry(hass: HomeAssistant, config_entry, controller) -> None:
"""Test entries are unloaded correctly.""" """Test entries are unloaded correctly."""
controller_manager = Mock(ControllerManager) controller_manager = Mock(ControllerManager)
hass.data[DOMAIN] = {DATA_CONTROLLER_MANAGER: controller_manager} config_entry.runtime_data = HeosRuntimeData(controller_manager, None, None, {})
with patch.object( with patch.object(
hass.config_entries, "async_forward_entry_unload", return_value=True hass.config_entries, "async_forward_entry_unload", return_value=True
) as unload: ) as unload:
@ -186,7 +175,7 @@ async def test_update_sources_retry(
assert await async_setup_component(hass, DOMAIN, config) assert await async_setup_component(hass, DOMAIN, config)
controller.get_favorites.reset_mock() controller.get_favorites.reset_mock()
controller.get_input_sources.reset_mock() controller.get_input_sources.reset_mock()
source_manager = hass.data[DOMAIN][DATA_SOURCE_MANAGER] source_manager = config_entry.runtime_data.source_manager
source_manager.retry_delay = 0 source_manager.retry_delay = 0
source_manager.max_retry_attempts = 1 source_manager.max_retry_attempts = 1
controller.get_favorites.side_effect = CommandFailedError("Test", "test", 0) controller.get_favorites.side_effect = CommandFailedError("Test", "test", 0)

View File

@ -8,11 +8,7 @@ from pyheos.error import HeosError
import pytest import pytest
from homeassistant.components.heos import media_player from homeassistant.components.heos import media_player
from homeassistant.components.heos.const import ( from homeassistant.components.heos.const import DOMAIN, SIGNAL_HEOS_UPDATED
DATA_SOURCE_MANAGER,
DOMAIN,
SIGNAL_HEOS_UPDATED,
)
from homeassistant.components.media_player import ( from homeassistant.components.media_player import (
ATTR_GROUP_MEMBERS, ATTR_GROUP_MEMBERS,
ATTR_INPUT_SOURCE, ATTR_INPUT_SOURCE,
@ -106,7 +102,7 @@ async def test_state_attributes(
assert ATTR_INPUT_SOURCE not in state.attributes assert ATTR_INPUT_SOURCE not in state.attributes
assert ( assert (
state.attributes[ATTR_INPUT_SOURCE_LIST] state.attributes[ATTR_INPUT_SOURCE_LIST]
== hass.data[DOMAIN][DATA_SOURCE_MANAGER].source_list == config_entry.runtime_data.source_manager.source_list
) )
@ -219,7 +215,7 @@ async def test_updates_from_sources_updated(
const.SIGNAL_CONTROLLER_EVENT, const.EVENT_SOURCES_CHANGED, {} const.SIGNAL_CONTROLLER_EVENT, const.EVENT_SOURCES_CHANGED, {}
) )
await event.wait() await event.wait()
source_list = hass.data[DOMAIN][DATA_SOURCE_MANAGER].source_list source_list = config_entry.runtime_data.source_manager.source_list
assert len(source_list) == 2 assert len(source_list) == 2
state = hass.states.get("media_player.test_player") state = hass.states.get("media_player.test_player")
assert state.attributes[ATTR_INPUT_SOURCE_LIST] == source_list assert state.attributes[ATTR_INPUT_SOURCE_LIST] == source_list
@ -318,7 +314,7 @@ async def test_updates_from_user_changed(
const.SIGNAL_CONTROLLER_EVENT, const.EVENT_USER_CHANGED, None const.SIGNAL_CONTROLLER_EVENT, const.EVENT_USER_CHANGED, None
) )
await event.wait() await event.wait()
source_list = hass.data[DOMAIN][DATA_SOURCE_MANAGER].source_list source_list = config_entry.runtime_data.source_manager.source_list
assert len(source_list) == 1 assert len(source_list) == 1
state = hass.states.get("media_player.test_player") state = hass.states.get("media_player.test_player")
assert state.attributes[ATTR_INPUT_SOURCE_LIST] == source_list assert state.attributes[ATTR_INPUT_SOURCE_LIST] == source_list