diff --git a/homeassistant/components/heos/__init__.py b/homeassistant/components/heos/__init__.py index e8d875d283c..8ca2040fd2f 100644 --- a/homeassistant/components/heos/__init__.py +++ b/homeassistant/components/heos/__init__.py @@ -8,20 +8,13 @@ from datetime import timedelta import logging from typing import Any -from pyheos import ( - Heos, - HeosError, - HeosPlayer, - PlayerUpdateResult, - SignalHeosEvent, - const as heos_const, -) +from pyheos import Heos, HeosError, HeosPlayer, const as heos_const from homeassistant.config_entries import ConfigEntry from homeassistant.const import Platform from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError, ServiceValidationError -from homeassistant.helpers import device_registry as dr, entity_registry as er +from homeassistant.helpers import device_registry as dr import homeassistant.helpers.config_validation as cv from homeassistant.helpers.dispatcher import ( async_dispatcher_connect, @@ -54,7 +47,6 @@ class HeosRuntimeData: """Runtime data and coordinators for HEOS config entries.""" coordinator: HeosCoordinator - controller_manager: ControllerManager group_manager: GroupManager source_manager: SourceManager players: dict[int, HeosPlayer] @@ -95,16 +87,13 @@ async def async_setup_entry(hass: HomeAssistant, entry: HeosConfigEntry) -> bool favorites = coordinator.favorites inputs = coordinator.inputs - controller_manager = ControllerManager(hass, controller) - await controller_manager.connect_listeners() - source_manager = SourceManager(favorites, inputs) source_manager.connect_update(hass, controller) group_manager = GroupManager(hass, controller, players) entry.runtime_data = HeosRuntimeData( - coordinator, controller_manager, group_manager, source_manager, players + coordinator, group_manager, source_manager, players ) group_manager.connect_update() @@ -120,85 +109,6 @@ async def async_unload_entry(hass: HomeAssistant, entry: HeosConfigEntry) -> boo return await hass.config_entries.async_unload_platforms(entry, PLATFORMS) -class ControllerManager: - """Class that manages events of the controller.""" - - def __init__(self, hass: HomeAssistant, controller: Heos) -> None: - """Init the controller manager.""" - self._hass = hass - self._device_registry: dr.DeviceRegistry | None = None - self._entity_registry: er.EntityRegistry | None = None - self.controller = controller - - async def connect_listeners(self): - """Subscribe to events of interest.""" - self._device_registry = dr.async_get(self._hass) - self._entity_registry = er.async_get(self._hass) - - # Handle controller events - self.controller.add_on_controller_event(self._controller_event) - - # Handle connection-related events - self.controller.add_on_heos_event(self._heos_event) - - async def disconnect(self): - """Disconnect subscriptions.""" - self.controller.dispatcher.disconnect_all() - await self.controller.disconnect() - - async def _controller_event( - self, event: str, data: PlayerUpdateResult | None - ) -> None: - """Handle controller event.""" - if event == heos_const.EVENT_PLAYERS_CHANGED: - assert data is not None - self.update_ids(data.updated_player_ids) - # Update players - async_dispatcher_send(self._hass, SIGNAL_HEOS_UPDATED) - - async def _heos_event(self, event): - """Handle connection event.""" - if event == SignalHeosEvent.CONNECTED: - try: - # Retrieve latest players and refresh status - data = await self.controller.load_players() - self.update_ids(data.updated_player_ids) - except HeosError as ex: - _LOGGER.error("Unable to refresh players: %s", ex) - # Update players - _LOGGER.debug("HEOS Controller event called, calling dispatcher") - async_dispatcher_send(self._hass, SIGNAL_HEOS_UPDATED) - - def update_ids(self, mapped_ids: dict[int, int]): - """Update the IDs in the device and entity registry.""" - # mapped_ids contains the mapped IDs (new:old) - for old_id, new_id in mapped_ids.items(): - # update device registry - assert self._device_registry is not None - entry = self._device_registry.async_get_device( - identifiers={(DOMAIN, str(old_id))} - ) - new_identifiers = {(DOMAIN, str(new_id))} - if entry: - self._device_registry.async_update_device( - entry.id, - new_identifiers=new_identifiers, - ) - _LOGGER.debug( - "Updated device %s identifiers to %s", entry.id, new_identifiers - ) - # update entity registry - assert self._entity_registry is not None - entity_id = self._entity_registry.async_get_entity_id( - Platform.MEDIA_PLAYER, DOMAIN, str(old_id) - ) - if entity_id: - self._entity_registry.async_update_entity( - entity_id, new_unique_id=str(new_id) - ) - _LOGGER.debug("Updated entity %s unique id to %s", entity_id, new_id) - - class GroupManager: """Class that manages HEOS groups.""" diff --git a/homeassistant/components/heos/config_flow.py b/homeassistant/components/heos/config_flow.py index 86d5123bccf..335b64977b8 100644 --- a/homeassistant/components/heos/config_flow.py +++ b/homeassistant/components/heos/config_flow.py @@ -2,7 +2,7 @@ from collections.abc import Mapping import logging -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any from urllib.parse import urlparse from pyheos import CommandAuthenticationError, Heos, HeosError, HeosOptions @@ -10,6 +10,7 @@ import voluptuous as vol from homeassistant.config_entries import ( ConfigEntry, + ConfigEntryState, ConfigFlow, ConfigFlowResult, OptionsFlow, @@ -22,6 +23,7 @@ from homeassistant.helpers.service_info.ssdp import ( SsdpServiceInfo, ) +from . import HeosConfigEntry from .const import DOMAIN _LOGGER = logging.getLogger(__name__) @@ -183,10 +185,12 @@ class HeosFlowHandler(ConfigFlow, domain=DOMAIN): ) -> ConfigFlowResult: """Validate account credentials and update options.""" errors: dict[str, str] = {} - entry = self._get_reauth_entry() + entry: HeosConfigEntry = self._get_reauth_entry() if user_input is not None: - heos = cast(Heos, entry.runtime_data.controller_manager.controller) - if await _validate_auth(user_input, heos, errors): + assert entry.state is ConfigEntryState.LOADED + if await _validate_auth( + user_input, entry.runtime_data.coordinator.heos, errors + ): return self.async_update_reload_and_abort(entry, options=user_input) return self.async_show_form( @@ -207,10 +211,10 @@ class HeosOptionsFlowHandler(OptionsFlow): """Manage the options.""" errors: dict[str, str] = {} if user_input is not None: - heos = cast( - Heos, self.config_entry.runtime_data.controller_manager.controller - ) - if await _validate_auth(user_input, heos, errors): + entry: HeosConfigEntry = self.config_entry + if await _validate_auth( + user_input, entry.runtime_data.coordinator.heos, errors + ): return self.async_create_entry(data=user_input) return self.async_show_form( diff --git a/homeassistant/components/heos/coordinator.py b/homeassistant/components/heos/coordinator.py index 8ccae0f63b6..9a59b54f6a3 100644 --- a/homeassistant/components/heos/coordinator.py +++ b/homeassistant/components/heos/coordinator.py @@ -7,12 +7,21 @@ entities to update. Entities subscribe to entity-specific updates within the ent import logging -from pyheos import Credentials, Heos, HeosError, HeosOptions, MediaItem +from pyheos import ( + Credentials, + Heos, + HeosError, + HeosOptions, + MediaItem, + PlayerUpdateResult, + const, +) from homeassistant.config_entries import ConfigEntry -from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_USERNAME +from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_USERNAME, Platform from homeassistant.core import HomeAssistant from homeassistant.exceptions import ConfigEntryNotReady +from homeassistant.helpers import device_registry as dr, entity_registry as er from homeassistant.helpers.update_coordinator import DataUpdateCoordinator from . import DOMAIN @@ -66,6 +75,10 @@ class HeosCoordinator(DataUpdateCoordinator[None]): ) # Retrieve initial data await self._async_update_sources() + # Attach event callbacks + self.heos.add_on_disconnected(self._async_on_disconnected) + self.heos.add_on_connected(self._async_on_reconnected) + self.heos.add_on_controller_event(self._async_on_controller_event) async def async_shutdown(self) -> None: """Disconnect all callbacks and disconnect from the device.""" @@ -78,6 +91,58 @@ class HeosCoordinator(DataUpdateCoordinator[None]): assert self.config_entry is not None self.config_entry.async_start_reauth(self.hass) + async def _async_on_disconnected(self) -> None: + """Handle when disconnected so entities are marked unavailable.""" + _LOGGER.warning("Connection to HEOS host %s lost", self.host) + self.async_update_listeners() + + async def _async_on_reconnected(self) -> None: + """Handle when reconnected so resources are updated and entities marked available.""" + await self._async_update_players() + _LOGGER.warning("Successfully reconnected to HEOS host %s", self.host) + self.async_update_listeners() + + async def _async_on_controller_event( + self, event: str, data: PlayerUpdateResult | None + ) -> None: + """Handle a controller event, such as players or groups changed.""" + if event == const.EVENT_PLAYERS_CHANGED: + assert data is not None + if data.updated_player_ids: + self._async_update_player_ids(data.updated_player_ids) + self.async_update_listeners() + + def _async_update_player_ids(self, updated_player_ids: dict[int, int]) -> None: + """Update the IDs in the device and entity registry.""" + device_registry = dr.async_get(self.hass) + entity_registry = er.async_get(self.hass) + # updated_player_ids contains the mapped IDs in format old:new + for old_id, new_id in updated_player_ids.items(): + # update device registry + entry = device_registry.async_get_device( + identifiers={(DOMAIN, str(old_id))} + ) + if entry: + new_identifiers = entry.identifiers.copy() + new_identifiers.remove((DOMAIN, str(old_id))) + new_identifiers.add((DOMAIN, str(new_id))) + device_registry.async_update_device( + entry.id, + new_identifiers=new_identifiers, + ) + _LOGGER.debug( + "Updated device %s identifiers to %s", entry.id, new_identifiers + ) + # update entity registry + entity_id = entity_registry.async_get_entity_id( + Platform.MEDIA_PLAYER, DOMAIN, str(old_id) + ) + if entity_id: + entity_registry.async_update_entity( + entity_id, new_unique_id=str(new_id) + ) + _LOGGER.debug("Updated entity %s unique id to %s", entity_id, new_id) + async def _async_update_sources(self) -> None: """Build source list for entities.""" # Get favorites only if reportedly signed in. @@ -91,3 +156,14 @@ class HeosCoordinator(DataUpdateCoordinator[None]): self.inputs = await self.heos.get_input_sources() except HeosError as error: _LOGGER.error("Unable to retrieve input sources: %s", error) + + async def _async_update_players(self) -> None: + """Update players after reconnection.""" + try: + player_updates = await self.heos.load_players() + except HeosError as error: + _LOGGER.error("Unable to refresh players: %s", error) + return + # After reconnecting, player_id may have changed + if player_updates.updated_player_ids: + self._async_update_player_ids(player_updates.updated_player_ids) diff --git a/homeassistant/components/heos/quality_scale.yaml b/homeassistant/components/heos/quality_scale.yaml index 81162ab9b97..2cd0ccaf567 100644 --- a/homeassistant/components/heos/quality_scale.yaml +++ b/homeassistant/components/heos/quality_scale.yaml @@ -29,10 +29,7 @@ rules: docs-installation-parameters: done entity-unavailable: done integration-owner: done - log-when-unavailable: - status: todo - comment: | - The integration currently spams the logs until reconnected + log-when-unavailable: done parallel-updates: done reauthentication-flow: done test-coverage: diff --git a/homeassistant/components/heos/services.py b/homeassistant/components/heos/services.py index a780c26fca6..00be409869a 100644 --- a/homeassistant/components/heos/services.py +++ b/homeassistant/components/heos/services.py @@ -64,7 +64,7 @@ def _get_controller(hass: HomeAssistant) -> Heos: raise HomeAssistantError( translation_domain=DOMAIN, translation_key="integration_not_loaded" ) - return entry.runtime_data.controller_manager.controller + return entry.runtime_data.coordinator.heos async def _sign_in_handler(service: ServiceCall) -> None: diff --git a/tests/components/heos/test_config_flow.py b/tests/components/heos/test_config_flow.py index 2f01e70e2d1..39ede354496 100644 --- a/tests/components/heos/test_config_flow.py +++ b/tests/components/heos/test_config_flow.py @@ -4,7 +4,7 @@ from pyheos import CommandAuthenticationError, CommandFailedError, Heos, HeosErr import pytest from homeassistant.components.heos.const import DOMAIN -from homeassistant.config_entries import SOURCE_SSDP, SOURCE_USER +from homeassistant.config_entries import SOURCE_SSDP, SOURCE_USER, ConfigEntryState from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_USERNAME from homeassistant.core import HomeAssistant from homeassistant.data_entry_flow import FlowResultType @@ -358,6 +358,7 @@ async def test_reauth_signs_in_aborts( config_entry.add_to_hass(hass) assert await hass.config_entries.async_setup(config_entry.entry_id) result = await config_entry.start_reauth_flow(hass) + assert config_entry.state is ConfigEntryState.LOADED assert result["step_id"] == "reauth_confirm" assert result["errors"] == {} @@ -396,6 +397,7 @@ async def test_reauth_signs_out( config_entry.add_to_hass(hass) assert await hass.config_entries.async_setup(config_entry.entry_id) result = await config_entry.start_reauth_flow(hass) + assert config_entry.state is ConfigEntryState.LOADED assert result["step_id"] == "reauth_confirm" assert result["errors"] == {} @@ -447,6 +449,7 @@ async def test_reauth_flow_missing_one_param_recovers( # Start the options flow. Entry has not current options. result = await config_entry.start_reauth_flow(hass) + assert config_entry.state is ConfigEntryState.LOADED assert result["step_id"] == "reauth_confirm" assert result["errors"] == {} assert result["type"] is FlowResultType.FORM diff --git a/tests/components/heos/test_media_player.py b/tests/components/heos/test_media_player.py index 00082c77f0f..539b4584502 100644 --- a/tests/components/heos/test_media_player.py +++ b/tests/components/heos/test_media_player.py @@ -172,6 +172,36 @@ async def test_updates_from_connection_event( assert "Unable to refresh players" in caplog.text +async def test_updates_from_connection_event_new_player_ids( + hass: HomeAssistant, + entity_registry: er.EntityRegistry, + device_registry: dr.DeviceRegistry, + config_entry: MockConfigEntry, + controller: Heos, + change_data_mapped_ids: PlayerUpdateResult, +) -> None: + """Test player ids changed after reconnection updates ids.""" + config_entry.add_to_hass(hass) + assert await hass.config_entries.async_setup(config_entry.entry_id) + + # Assert current IDs + assert device_registry.async_get_device(identifiers={(DOMAIN, "1")}) + assert entity_registry.async_get_entity_id(MEDIA_PLAYER_DOMAIN, DOMAIN, "1") + + # Send event which will result in updated IDs. + controller.load_players.return_value = change_data_mapped_ids + await controller.dispatcher.wait_send( + SignalType.HEOS_EVENT, SignalHeosEvent.CONNECTED + ) + await hass.async_block_till_done() + + # Assert updated IDs and previous don't exist + assert not device_registry.async_get_device(identifiers={(DOMAIN, "1")}) + assert device_registry.async_get_device(identifiers={(DOMAIN, "101")}) + assert not entity_registry.async_get_entity_id(MEDIA_PLAYER_DOMAIN, DOMAIN, "1") + assert entity_registry.async_get_entity_id(MEDIA_PLAYER_DOMAIN, DOMAIN, "101") + + async def test_updates_from_sources_updated( hass: HomeAssistant, config_entry: MockConfigEntry,