diff --git a/homeassistant/components/heos/coordinator.py b/homeassistant/components/heos/coordinator.py index 94aa4ad0ab5..0303d150794 100644 --- a/homeassistant/components/heos/coordinator.py +++ b/homeassistant/components/heos/coordinator.py @@ -16,6 +16,7 @@ from pyheos import ( HeosError, HeosNowPlayingMedia, HeosOptions, + HeosPlayer, MediaItem, MediaType, PlayerUpdateResult, @@ -58,6 +59,7 @@ class HeosCoordinator(DataUpdateCoordinator[None]): credentials=credentials, ) ) + self._platform_callbacks: list[Callable[[Sequence[HeosPlayer]], None]] = [] self._update_sources_pending: bool = False self._source_list: list[str] = [] self._favorites: dict[int, MediaItem] = {} @@ -124,6 +126,27 @@ class HeosCoordinator(DataUpdateCoordinator[None]): self.async_update_listeners() return remove_listener + def async_add_platform_callback( + self, add_entities_callback: Callable[[Sequence[HeosPlayer]], None] + ) -> None: + """Add a callback to add entities for a platform.""" + self._platform_callbacks.append(add_entities_callback) + + def _async_handle_player_update_result( + self, update_result: PlayerUpdateResult + ) -> None: + """Handle a player update result.""" + if update_result.added_player_ids and self._platform_callbacks: + new_players = [ + self.heos.players[player_id] + for player_id in update_result.added_player_ids + ] + for add_entities_callback in self._platform_callbacks: + add_entities_callback(new_players) + + if update_result.updated_player_ids: + self._async_update_player_ids(update_result.updated_player_ids) + async def _async_on_auth_failure(self) -> None: """Handle when the user credentials are no longer valid.""" assert self.config_entry is not None @@ -147,8 +170,7 @@ class HeosCoordinator(DataUpdateCoordinator[None]): """Handle a controller event, such as players or groups changed.""" if event == const.EVENT_PLAYERS_CHANGED: assert data is not None - if data.updated_player_ids: - self._async_update_player_ids(data.updated_player_ids) + self._async_handle_player_update_result(data) elif ( event in (const.EVENT_SOURCES_CHANGED, const.EVENT_USER_CHANGED) and not self._update_sources_pending @@ -242,9 +264,7 @@ class HeosCoordinator(DataUpdateCoordinator[None]): except HeosError as error: _LOGGER.error("Unable to refresh players: %s", error) return - # After reconnecting, player_id may have changed - if player_updates.updated_player_ids: - self._async_update_player_ids(player_updates.updated_player_ids) + self._async_handle_player_update_result(player_updates) @callback def async_get_source_list(self) -> list[str]: diff --git a/homeassistant/components/heos/media_player.py b/homeassistant/components/heos/media_player.py index 4dbaead67a7..b9aa05810e5 100644 --- a/homeassistant/components/heos/media_player.py +++ b/homeassistant/components/heos/media_player.py @@ -2,7 +2,7 @@ from __future__ import annotations -from collections.abc import Awaitable, Callable, Coroutine +from collections.abc import Awaitable, Callable, Coroutine, Sequence from datetime import datetime from functools import reduce, wraps from operator import ior @@ -93,11 +93,16 @@ async def async_setup_entry( async_add_entities: AddConfigEntryEntitiesCallback, ) -> None: """Add media players for a config entry.""" - devices = [ - HeosMediaPlayer(entry.runtime_data, player) - for player in entry.runtime_data.heos.players.values() - ] - async_add_entities(devices) + + def add_entities_callback(players: Sequence[HeosPlayer]) -> None: + """Add entities for each player.""" + async_add_entities( + [HeosMediaPlayer(entry.runtime_data, player) for player in players] + ) + + coordinator = entry.runtime_data + coordinator.async_add_platform_callback(add_entities_callback) + add_entities_callback(list(coordinator.heos.players.values())) type _FuncType[**_P] = Callable[_P, Awaitable[Any]] diff --git a/homeassistant/components/heos/quality_scale.yaml b/homeassistant/components/heos/quality_scale.yaml index a1220366fa3..a08e2dca544 100644 --- a/homeassistant/components/heos/quality_scale.yaml +++ b/homeassistant/components/heos/quality_scale.yaml @@ -49,7 +49,7 @@ rules: docs-supported-functions: done docs-troubleshooting: done docs-use-cases: done - dynamic-devices: todo + dynamic-devices: done entity-category: done entity-device-class: done entity-disabled-by-default: done diff --git a/tests/components/heos/conftest.py b/tests/components/heos/conftest.py index 39937a8355f..7bed05a0289 100644 --- a/tests/components/heos/conftest.py +++ b/tests/components/heos/conftest.py @@ -2,7 +2,7 @@ from __future__ import annotations -from collections.abc import Iterator +from collections.abc import Callable, Iterator from unittest.mock import Mock, patch from pyheos import ( @@ -130,16 +130,17 @@ def system_info_fixture() -> HeosSystem: ) -@pytest.fixture(name="players") -def players_fixture() -> dict[int, HeosPlayer]: - """Create two mock HeosPlayers.""" - players = {} - for i in (1, 2): - player = HeosPlayer( - player_id=i, +@pytest.fixture(name="player_factory") +def player_factory_fixture() -> Callable[[int, str, str], HeosPlayer]: + """Return a method that creates players.""" + + def factory(player_id: int, name: str, model: str) -> HeosPlayer: + """Create a player.""" + return HeosPlayer( + player_id=player_id, group_id=999, - name="Test Player" if i == 1 else f"Test Player {i}", - model="HEOS Drive HS2" if i == 1 else "Speaker", + name=name, + model=model, serial="123456", version="1.0.0", supported_version=True, @@ -147,26 +148,37 @@ def players_fixture() -> dict[int, HeosPlayer]: is_muted=False, available=True, state=PlayState.STOP, - ip_address=f"127.0.0.{i}", + ip_address=f"127.0.0.{player_id}", network=NetworkType.WIRED, shuffle=False, repeat=RepeatType.OFF, volume=25, + now_playing_media=HeosNowPlayingMedia( + type=MediaType.STATION, + song="Song", + station="Station Name", + album="Album", + artist="Artist", + image_url="http://", + album_id="1", + media_id="1", + queue_id=1, + source_id=10, + ), ) - player.now_playing_media = HeosNowPlayingMedia( - type=MediaType.STATION, - song="Song", - station="Station Name", - album="Album", - artist="Artist", - image_url="http://", - album_id="1", - media_id="1", - queue_id=1, - source_id=10, - ) - players[player.player_id] = player - return players + + return factory + + +@pytest.fixture(name="players") +def players_fixture( + player_factory: Callable[[int, str, str], HeosPlayer], +) -> dict[int, HeosPlayer]: + """Create two mock HeosPlayers.""" + return { + 1: player_factory(1, "Test Player", "HEOS Drive HS2"), + 2: player_factory(2, "Test Player 2", "Speaker"), + } @pytest.fixture(name="group") diff --git a/tests/components/heos/test_init.py b/tests/components/heos/test_init.py index 60bc2a72e51..87cc8dd7dde 100644 --- a/tests/components/heos/test_init.py +++ b/tests/components/heos/test_init.py @@ -1,16 +1,26 @@ """Tests for the init module.""" +from collections.abc import Callable from typing import cast from unittest.mock import Mock -from pyheos import HeosError, HeosOptions, SignalHeosEvent, SignalType +from pyheos import ( + HeosError, + HeosOptions, + HeosPlayer, + PlayerUpdateResult, + SignalHeosEvent, + SignalType, + const, +) import pytest from homeassistant.components.heos.const import DOMAIN +from homeassistant.components.media_player import DOMAIN as MEDIA_PLAYER_DOMAIN from homeassistant.config_entries import SOURCE_REAUTH, ConfigEntryState from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_USERNAME from homeassistant.core import HomeAssistant -from homeassistant.helpers import device_registry as dr +from homeassistant.helpers import device_registry as dr, entity_registry as er from homeassistant.setup import async_setup_component from . import MockHeos @@ -255,3 +265,64 @@ async def test_remove_config_entry_device( ws_client = await hass_ws_client(hass) response = await ws_client.remove_device(device_entry.id, config_entry.entry_id) assert response["success"] == expected_result + + +async def test_reconnected_new_entities_created( + hass: HomeAssistant, + entity_registry: er.EntityRegistry, + config_entry: MockConfigEntry, + controller: MockHeos, + player_factory: Callable[[int, str, str], HeosPlayer], +) -> None: + """Test new entities are created for new players after reconnecting.""" + config_entry.add_to_hass(hass) + assert await hass.config_entries.async_setup(config_entry.entry_id) + + # Assert initial entity doesn't exist + assert not entity_registry.async_get_entity_id(MEDIA_PLAYER_DOMAIN, DOMAIN, "3") + + # Create player + players = controller.players.copy() + players[3] = player_factory(3, "Test Player 3", "HEOS Link") + controller.mock_set_players(players) + controller.load_players.return_value = PlayerUpdateResult([3], [], {}) + + # Simulate reconnection + await controller.dispatcher.wait_send( + SignalType.HEOS_EVENT, SignalHeosEvent.CONNECTED + ) + await hass.async_block_till_done() + + # Assert new entity created + assert entity_registry.async_get_entity_id(MEDIA_PLAYER_DOMAIN, DOMAIN, "3") + + +async def test_players_changed_new_entities_created( + hass: HomeAssistant, + entity_registry: er.EntityRegistry, + config_entry: MockConfigEntry, + controller: MockHeos, + player_factory: Callable[[int, str, str], HeosPlayer], +) -> None: + """Test new entities are created for new players on change event.""" + config_entry.add_to_hass(hass) + assert await hass.config_entries.async_setup(config_entry.entry_id) + + # Assert initial entity doesn't exist + assert not entity_registry.async_get_entity_id(MEDIA_PLAYER_DOMAIN, DOMAIN, "3") + + # Create player + players = controller.players.copy() + players[3] = player_factory(3, "Test Player 3", "HEOS Link") + controller.mock_set_players(players) + + # Simulate players changed event + await controller.dispatcher.wait_send( + SignalType.CONTROLLER_EVENT, + const.EVENT_PLAYERS_CHANGED, + PlayerUpdateResult([3], [], {}), + ) + await hass.async_block_till_done() + + # Assert new entity created + assert entity_registry.async_get_entity_id(MEDIA_PLAYER_DOMAIN, DOMAIN, "3")