From 39b4e890a0433f3d49c21e1c1dab624323dbcc24 Mon Sep 17 00:00:00 2001 From: epenet <6771947+epenet@users.noreply.github.com> Date: Wed, 22 May 2024 09:20:05 +0200 Subject: [PATCH] Add coordinator to SamsungTV (#117863) * Introduce samsungtv coordinator * Adjust * Adjust media_player * Remove remote * Adjust * Fix mypy * Adjust * Use coordinator.async_refresh --- .../components/samsungtv/__init__.py | 7 +- .../components/samsungtv/coordinator.py | 50 ++++++++++++++ .../components/samsungtv/diagnostics.py | 4 +- homeassistant/components/samsungtv/entity.py | 12 ++-- homeassistant/components/samsungtv/helpers.py | 2 +- .../components/samsungtv/media_player.py | 62 ++++++++--------- homeassistant/components/samsungtv/remote.py | 4 +- .../components/samsungtv/test_media_player.py | 66 +++++-------------- 8 files changed, 115 insertions(+), 92 deletions(-) create mode 100644 homeassistant/components/samsungtv/coordinator.py diff --git a/homeassistant/components/samsungtv/__init__.py b/homeassistant/components/samsungtv/__init__.py index 27d571bc37b..0b2785f77bc 100644 --- a/homeassistant/components/samsungtv/__init__.py +++ b/homeassistant/components/samsungtv/__init__.py @@ -49,12 +49,13 @@ from .const import ( UPNP_SVC_MAIN_TV_AGENT, UPNP_SVC_RENDERING_CONTROL, ) +from .coordinator import SamsungTVDataUpdateCoordinator PLATFORMS = [Platform.MEDIA_PLAYER, Platform.REMOTE] CONFIG_SCHEMA = cv.removed(DOMAIN, raise_if_present=False) -SamsungTVConfigEntry = ConfigEntry[SamsungTVBridge] +SamsungTVConfigEntry = ConfigEntry[SamsungTVDataUpdateCoordinator] @callback @@ -179,7 +180,9 @@ async def async_setup_entry(hass: HomeAssistant, entry: SamsungTVConfigEntry) -> entry.async_on_unload(debounced_reloader.async_shutdown) entry.async_on_unload(entry.add_update_listener(debounced_reloader.async_call)) - entry.runtime_data = bridge + coordinator = SamsungTVDataUpdateCoordinator(hass, bridge) + await coordinator.async_config_entry_first_refresh() + entry.runtime_data = coordinator await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) return True diff --git a/homeassistant/components/samsungtv/coordinator.py b/homeassistant/components/samsungtv/coordinator.py new file mode 100644 index 00000000000..92d8dc8fa84 --- /dev/null +++ b/homeassistant/components/samsungtv/coordinator.py @@ -0,0 +1,50 @@ +"""Coordinator for the SamsungTV integration.""" + +from __future__ import annotations + +from collections.abc import Callable, Coroutine +from datetime import timedelta +from typing import Any + +from homeassistant.config_entries import ConfigEntry +from homeassistant.core import HomeAssistant +from homeassistant.helpers.update_coordinator import DataUpdateCoordinator + +from .bridge import SamsungTVBridge +from .const import DOMAIN, LOGGER + +SCAN_INTERVAL = 10 + + +class SamsungTVDataUpdateCoordinator(DataUpdateCoordinator[None]): + """Coordinator for the SamsungTV integration.""" + + config_entry: ConfigEntry + + def __init__(self, hass: HomeAssistant, bridge: SamsungTVBridge) -> None: + """Initialize the coordinator.""" + super().__init__( + hass, + LOGGER, + name=DOMAIN, + update_interval=timedelta(seconds=SCAN_INTERVAL), + ) + + self.bridge = bridge + self.is_on: bool | None = False + self.async_extra_update: Callable[[], Coroutine[Any, Any, None]] | None = None + + async def _async_update_data(self) -> None: + """Fetch data from SamsungTV bridge.""" + if self.bridge.auth_failed or self.hass.is_stopping: + return + old_state = self.is_on + if self.bridge.power_off_in_progress: + self.is_on = False + else: + self.is_on = await self.bridge.async_is_on() + if self.is_on != old_state: + LOGGER.debug("TV %s state updated to %s", self.bridge.host, self.is_on) + + if self.async_extra_update: + await self.async_extra_update() diff --git a/homeassistant/components/samsungtv/diagnostics.py b/homeassistant/components/samsungtv/diagnostics.py index a0da9a59261..ebca8d2543b 100644 --- a/homeassistant/components/samsungtv/diagnostics.py +++ b/homeassistant/components/samsungtv/diagnostics.py @@ -18,8 +18,8 @@ async def async_get_config_entry_diagnostics( hass: HomeAssistant, entry: SamsungTVConfigEntry ) -> dict[str, Any]: """Return diagnostics for a config entry.""" - bridge = entry.runtime_data + coordinator = entry.runtime_data return { "entry": async_redact_data(entry.as_dict(), TO_REDACT), - "device_info": await bridge.async_device_info(), + "device_info": await coordinator.bridge.async_device_info(), } diff --git a/homeassistant/components/samsungtv/entity.py b/homeassistant/components/samsungtv/entity.py index fc1c5bf7715..e2c1fb66bcc 100644 --- a/homeassistant/components/samsungtv/entity.py +++ b/homeassistant/components/samsungtv/entity.py @@ -4,7 +4,6 @@ from __future__ import annotations from wakeonlan import send_magic_packet -from homeassistant.config_entries import ConfigEntry from homeassistant.const import ( ATTR_CONNECTIONS, ATTR_IDENTIFIERS, @@ -17,20 +16,23 @@ from homeassistant.helpers import device_registry as dr from homeassistant.helpers.device_registry import DeviceInfo from homeassistant.helpers.entity import Entity from homeassistant.helpers.trigger import PluggableAction +from homeassistant.helpers.update_coordinator import CoordinatorEntity -from .bridge import SamsungTVBridge from .const import CONF_MANUFACTURER, DOMAIN +from .coordinator import SamsungTVDataUpdateCoordinator from .triggers.turn_on import async_get_turn_on_trigger -class SamsungTVEntity(Entity): +class SamsungTVEntity(CoordinatorEntity[SamsungTVDataUpdateCoordinator], Entity): """Defines a base SamsungTV entity.""" _attr_has_entity_name = True - def __init__(self, *, bridge: SamsungTVBridge, config_entry: ConfigEntry) -> None: + def __init__(self, *, coordinator: SamsungTVDataUpdateCoordinator) -> None: """Initialize the SamsungTV entity.""" - self._bridge = bridge + super().__init__(coordinator) + self._bridge = coordinator.bridge + config_entry = coordinator.config_entry self._mac: str | None = config_entry.data.get(CONF_MAC) self._host: str | None = config_entry.data.get(CONF_HOST) # Fallback for legacy models that doesn't have a API to retrieve MAC or SerialNumber diff --git a/homeassistant/components/samsungtv/helpers.py b/homeassistant/components/samsungtv/helpers.py index 4ee881a3631..4e8dd00d486 100644 --- a/homeassistant/components/samsungtv/helpers.py +++ b/homeassistant/components/samsungtv/helpers.py @@ -58,7 +58,7 @@ def async_get_client_by_device_entry( for config_entry_id in device.config_entries: entry = hass.config_entries.async_get_entry(config_entry_id) if entry and entry.domain == DOMAIN and entry.state is ConfigEntryState.LOADED: - return entry.runtime_data + return entry.runtime_data.bridge raise ValueError( f"Device {device.id} is not from an existing {DOMAIN} config entry" diff --git a/homeassistant/components/samsungtv/media_player.py b/homeassistant/components/samsungtv/media_player.py index 12952f72d2e..6b984130f70 100644 --- a/homeassistant/components/samsungtv/media_player.py +++ b/homeassistant/components/samsungtv/media_player.py @@ -28,7 +28,6 @@ from homeassistant.components.media_player import ( MediaPlayerState, MediaType, ) -from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant, callback from homeassistant.helpers import config_validation as cv from homeassistant.helpers.aiohttp_client import async_get_clientsession @@ -36,8 +35,9 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.util.async_ import create_eager_task from . import SamsungTVConfigEntry -from .bridge import SamsungTVBridge, SamsungTVWSBridge +from .bridge import SamsungTVWSBridge from .const import CONF_SSDP_RENDERING_CONTROL_LOCATION, LOGGER +from .coordinator import SamsungTVDataUpdateCoordinator from .entity import SamsungTVEntity SOURCES = {"TV": "KEY_TV", "HDMI": "KEY_HDMI"} @@ -67,8 +67,8 @@ async def async_setup_entry( async_add_entities: AddEntitiesCallback, ) -> None: """Set up the Samsung TV from a config entry.""" - bridge = entry.runtime_data - async_add_entities([SamsungTVDevice(bridge, entry)], True) + coordinator = entry.runtime_data + async_add_entities([SamsungTVDevice(coordinator)]) class SamsungTVDevice(SamsungTVEntity, MediaPlayerEntity): @@ -78,16 +78,11 @@ class SamsungTVDevice(SamsungTVEntity, MediaPlayerEntity): _attr_name = None _attr_device_class = MediaPlayerDeviceClass.TV - def __init__( - self, - bridge: SamsungTVBridge, - config_entry: ConfigEntry, - ) -> None: + def __init__(self, coordinator: SamsungTVDataUpdateCoordinator) -> None: """Initialize the Samsung device.""" - super().__init__(bridge=bridge, config_entry=config_entry) - self._config_entry = config_entry - self._ssdp_rendering_control_location: str | None = config_entry.data.get( - CONF_SSDP_RENDERING_CONTROL_LOCATION + super().__init__(coordinator=coordinator) + self._ssdp_rendering_control_location: str | None = ( + coordinator.config_entry.data.get(CONF_SSDP_RENDERING_CONTROL_LOCATION) ) # Assume that the TV is in Play mode self._playing: bool = True @@ -130,27 +125,35 @@ class SamsungTVDevice(SamsungTVEntity, MediaPlayerEntity): self._update_sources() self._app_list_event.set() + async def async_added_to_hass(self) -> None: + """Run when entity about to be added to hass.""" + await super().async_added_to_hass() + await self._async_extra_update() + self.coordinator.async_extra_update = self._async_extra_update + if self.coordinator.is_on: + self._attr_state = MediaPlayerState.ON + self._update_from_upnp() + else: + self._attr_state = MediaPlayerState.OFF + async def async_will_remove_from_hass(self) -> None: """Handle removal.""" + self.coordinator.async_extra_update = None await self._async_shutdown_dmr() - async def async_update(self) -> None: - """Update state of device.""" - if self._bridge.auth_failed or self.hass.is_stopping: - return - old_state = self._attr_state - if self._bridge.power_off_in_progress: - self._attr_state = MediaPlayerState.OFF + @callback + def _handle_coordinator_update(self) -> None: + """Handle data update.""" + if self.coordinator.is_on: + self._attr_state = MediaPlayerState.ON + self._update_from_upnp() else: - self._attr_state = ( - MediaPlayerState.ON - if await self._bridge.async_is_on() - else MediaPlayerState.OFF - ) - if self._attr_state != old_state: - LOGGER.debug("TV %s state updated to %s", self._host, self.state) + self._attr_state = MediaPlayerState.OFF + self.async_write_ha_state() - if self._attr_state != MediaPlayerState.ON: + async def _async_extra_update(self) -> None: + """Update state of device.""" + if not self.coordinator.is_on: if self._dmr_device and self._dmr_device.is_subscribed: await self._dmr_device.async_unsubscribe_services() return @@ -168,8 +171,6 @@ class SamsungTVDevice(SamsungTVEntity, MediaPlayerEntity): if startup_tasks: await asyncio.gather(*startup_tasks) - self._update_from_upnp() - @callback def _update_from_upnp(self) -> bool: # Upnp events can affect other attributes that we currently do not track @@ -311,6 +312,7 @@ class SamsungTVDevice(SamsungTVEntity, MediaPlayerEntity): async def async_turn_off(self) -> None: """Turn off media player.""" await self._bridge.async_power_off() + await self.coordinator.async_refresh() async def async_set_volume_level(self, volume: float) -> None: """Set volume level on the media player.""" diff --git a/homeassistant/components/samsungtv/remote.py b/homeassistant/components/samsungtv/remote.py index 6c6bc6774d3..29681f96ab7 100644 --- a/homeassistant/components/samsungtv/remote.py +++ b/homeassistant/components/samsungtv/remote.py @@ -21,8 +21,8 @@ async def async_setup_entry( async_add_entities: AddEntitiesCallback, ) -> None: """Set up the Samsung TV from a config entry.""" - bridge = entry.runtime_data - async_add_entities([SamsungTVRemote(bridge=bridge, config_entry=entry)]) + coordinator = entry.runtime_data + async_add_entities([SamsungTVRemote(coordinator=coordinator)]) class SamsungTVRemote(SamsungTVEntity, RemoteEntity): diff --git a/tests/components/samsungtv/test_media_player.py b/tests/components/samsungtv/test_media_player.py index 639530fa892..4c7ee0e116d 100644 --- a/tests/components/samsungtv/test_media_player.py +++ b/tests/components/samsungtv/test_media_player.py @@ -552,11 +552,9 @@ async def test_send_key(hass: HomeAssistant, remote: Mock) -> None: DOMAIN, SERVICE_VOLUME_UP, {ATTR_ENTITY_ID: ENTITY_ID}, True ) state = hass.states.get(ENTITY_ID) - # key and update called + # key called assert remote.control.call_count == 1 assert remote.control.call_args_list == [call("KEY_VOLUP")] - assert remote.close.call_count == 1 - assert remote.close.call_args_list == [call()] assert state.state == STATE_ON @@ -583,14 +581,12 @@ async def test_send_key_connection_closed_retry_succeed( DOMAIN, SERVICE_VOLUME_UP, {ATTR_ENTITY_ID: ENTITY_ID}, True ) state = hass.states.get(ENTITY_ID) - # key because of retry two times and update called + # key because of retry two times assert remote.control.call_count == 2 assert remote.control.call_args_list == [ call("KEY_VOLUP"), call("KEY_VOLUP"), ] - assert remote.close.call_count == 1 - assert remote.close.call_args_list == [call()] assert state.state == STATE_ON @@ -914,11 +910,9 @@ async def test_volume_up(hass: HomeAssistant, remote: Mock) -> None: await hass.services.async_call( DOMAIN, SERVICE_VOLUME_UP, {ATTR_ENTITY_ID: ENTITY_ID}, True ) - # key and update called + # key called assert remote.control.call_count == 1 assert remote.control.call_args_list == [call("KEY_VOLUP")] - assert remote.close.call_count == 1 - assert remote.close.call_args_list == [call()] async def test_volume_down(hass: HomeAssistant, remote: Mock) -> None: @@ -927,11 +921,9 @@ async def test_volume_down(hass: HomeAssistant, remote: Mock) -> None: await hass.services.async_call( DOMAIN, SERVICE_VOLUME_DOWN, {ATTR_ENTITY_ID: ENTITY_ID}, True ) - # key and update called + # key called assert remote.control.call_count == 1 assert remote.control.call_args_list == [call("KEY_VOLDOWN")] - assert remote.close.call_count == 1 - assert remote.close.call_args_list == [call()] async def test_mute_volume(hass: HomeAssistant, remote: Mock) -> None: @@ -943,11 +935,9 @@ async def test_mute_volume(hass: HomeAssistant, remote: Mock) -> None: {ATTR_ENTITY_ID: ENTITY_ID, ATTR_MEDIA_VOLUME_MUTED: True}, True, ) - # key and update called + # key called assert remote.control.call_count == 1 assert remote.control.call_args_list == [call("KEY_MUTE")] - assert remote.close.call_count == 1 - assert remote.close.call_args_list == [call()] async def test_media_play(hass: HomeAssistant, remote: Mock) -> None: @@ -956,20 +946,16 @@ async def test_media_play(hass: HomeAssistant, remote: Mock) -> None: await hass.services.async_call( DOMAIN, SERVICE_MEDIA_PLAY, {ATTR_ENTITY_ID: ENTITY_ID}, True ) - # key and update called + # key called assert remote.control.call_count == 1 assert remote.control.call_args_list == [call("KEY_PLAY")] - assert remote.close.call_count == 1 - assert remote.close.call_args_list == [call()] await hass.services.async_call( DOMAIN, SERVICE_MEDIA_PLAY_PAUSE, {ATTR_ENTITY_ID: ENTITY_ID}, True ) - # key and update called + # key called assert remote.control.call_count == 2 assert remote.control.call_args_list == [call("KEY_PLAY"), call("KEY_PAUSE")] - assert remote.close.call_count == 2 - assert remote.close.call_args_list == [call(), call()] async def test_media_pause(hass: HomeAssistant, remote: Mock) -> None: @@ -978,20 +964,16 @@ async def test_media_pause(hass: HomeAssistant, remote: Mock) -> None: await hass.services.async_call( DOMAIN, SERVICE_MEDIA_PAUSE, {ATTR_ENTITY_ID: ENTITY_ID}, True ) - # key and update called + # key called assert remote.control.call_count == 1 assert remote.control.call_args_list == [call("KEY_PAUSE")] - assert remote.close.call_count == 1 - assert remote.close.call_args_list == [call()] await hass.services.async_call( DOMAIN, SERVICE_MEDIA_PLAY_PAUSE, {ATTR_ENTITY_ID: ENTITY_ID}, True ) - # key and update called + # key called assert remote.control.call_count == 2 assert remote.control.call_args_list == [call("KEY_PAUSE"), call("KEY_PLAY")] - assert remote.close.call_count == 2 - assert remote.close.call_args_list == [call(), call()] async def test_media_next_track(hass: HomeAssistant, remote: Mock) -> None: @@ -1000,11 +982,9 @@ async def test_media_next_track(hass: HomeAssistant, remote: Mock) -> None: await hass.services.async_call( DOMAIN, SERVICE_MEDIA_NEXT_TRACK, {ATTR_ENTITY_ID: ENTITY_ID}, True ) - # key and update called + # key called assert remote.control.call_count == 1 assert remote.control.call_args_list == [call("KEY_CHUP")] - assert remote.close.call_count == 1 - assert remote.close.call_args_list == [call()] async def test_media_previous_track(hass: HomeAssistant, remote: Mock) -> None: @@ -1013,11 +993,9 @@ async def test_media_previous_track(hass: HomeAssistant, remote: Mock) -> None: await hass.services.async_call( DOMAIN, SERVICE_MEDIA_PREVIOUS_TRACK, {ATTR_ENTITY_ID: ENTITY_ID}, True ) - # key and update called + # key called assert remote.control.call_count == 1 assert remote.control.call_args_list == [call("KEY_CHDOWN")] - assert remote.close.call_count == 1 - assert remote.close.call_args_list == [call()] @pytest.mark.usefixtures("remotews", "rest_api") @@ -1074,8 +1052,6 @@ async def test_play_media(hass: HomeAssistant, remote: Mock) -> None: call("KEY_6"), call("KEY_ENTER"), ] - assert remote.close.call_count == 1 - assert remote.close.call_args_list == [call()] assert sleep.call_count == 3 @@ -1095,10 +1071,8 @@ async def test_play_media_invalid_type(hass: HomeAssistant) -> None: }, True, ) - # only update called + # control not called assert remote.control.call_count == 0 - assert remote.close.call_count == 0 - assert remote.call_count == 1 async def test_play_media_channel_as_string(hass: HomeAssistant) -> None: @@ -1117,10 +1091,8 @@ async def test_play_media_channel_as_string(hass: HomeAssistant) -> None: }, True, ) - # only update called + # control not called assert remote.control.call_count == 0 - assert remote.close.call_count == 0 - assert remote.call_count == 1 async def test_play_media_channel_as_non_positive(hass: HomeAssistant) -> None: @@ -1138,10 +1110,8 @@ async def test_play_media_channel_as_non_positive(hass: HomeAssistant) -> None: }, True, ) - # only update called + # control not called assert remote.control.call_count == 0 - assert remote.close.call_count == 0 - assert remote.call_count == 1 async def test_select_source(hass: HomeAssistant, remote: Mock) -> None: @@ -1153,11 +1123,9 @@ async def test_select_source(hass: HomeAssistant, remote: Mock) -> None: {ATTR_ENTITY_ID: ENTITY_ID, ATTR_INPUT_SOURCE: "HDMI"}, True, ) - # key and update called + # key called assert remote.control.call_count == 1 assert remote.control.call_args_list == [call("KEY_HDMI")] - assert remote.close.call_count == 1 - assert remote.close.call_args_list == [call()] async def test_select_source_invalid_source(hass: HomeAssistant) -> None: @@ -1171,10 +1139,8 @@ async def test_select_source_invalid_source(hass: HomeAssistant) -> None: {ATTR_ENTITY_ID: ENTITY_ID, ATTR_INPUT_SOURCE: "INVALID"}, True, ) - # only update called + # control not called assert remote.control.call_count == 0 - assert remote.close.call_count == 0 - assert remote.call_count == 1 @pytest.mark.usefixtures("rest_api")