diff --git a/.coveragerc b/.coveragerc index 3e8de435832..41b46796373 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1396,6 +1396,7 @@ omit = homeassistant/components/trafikverket_weatherstation/coordinator.py homeassistant/components/trafikverket_weatherstation/sensor.py homeassistant/components/transmission/__init__.py + homeassistant/components/transmission/coordinator.py homeassistant/components/transmission/sensor.py homeassistant/components/transmission/switch.py homeassistant/components/travisci/sensor.py diff --git a/homeassistant/components/transmission/__init__.py b/homeassistant/components/transmission/__init__.py index 7e02c3d419d..be32c95356d 100644 --- a/homeassistant/components/transmission/__init__.py +++ b/homeassistant/components/transmission/__init__.py @@ -1,8 +1,7 @@ """Support for the Transmission BitTorrent client API.""" from __future__ import annotations -from collections.abc import Callable -from datetime import datetime, timedelta +from datetime import timedelta from functools import partial import logging import re @@ -14,10 +13,9 @@ from transmission_rpc.error import ( TransmissionConnectError, TransmissionError, ) -from transmission_rpc.session import SessionStats import voluptuous as vol -from homeassistant.config_entries import ConfigEntry, ConfigEntryState +from homeassistant.config_entries import ConfigEntry from homeassistant.const import ( CONF_HOST, CONF_ID, @@ -35,33 +33,39 @@ from homeassistant.helpers import ( entity_registry as er, selector, ) -from homeassistant.helpers.dispatcher import dispatcher_send -from homeassistant.helpers.event import async_track_time_interval from .const import ( ATTR_DELETE_DATA, ATTR_TORRENT, CONF_ENTRY_ID, - CONF_LIMIT, - CONF_ORDER, - DATA_UPDATED, DEFAULT_DELETE_DATA, - DEFAULT_LIMIT, - DEFAULT_ORDER, - DEFAULT_SCAN_INTERVAL, DOMAIN, - EVENT_DOWNLOADED_TORRENT, - EVENT_REMOVED_TORRENT, - EVENT_STARTED_TORRENT, SERVICE_ADD_TORRENT, SERVICE_REMOVE_TORRENT, SERVICE_START_TORRENT, SERVICE_STOP_TORRENT, ) +from .coordinator import TransmissionDataUpdateCoordinator from .errors import AuthenticationError, CannotConnect, UnknownError _LOGGER = logging.getLogger(__name__) +PLATFORMS = [Platform.SENSOR, Platform.SWITCH] + +MIGRATION_NAME_TO_KEY = { + # Sensors + "Down Speed": "download", + "Up Speed": "upload", + "Status": "status", + "Active Torrents": "active_torrents", + "Paused Torrents": "paused_torrents", + "Total Torrents": "total_torrents", + "Completed Torrents": "completed_torrents", + "Started Torrents": "started_torrents", + # Switches + "Switch": "on_off", + "Turtle Mode": "turtle_mode", +} SERVICE_BASE_SCHEMA = vol.Schema( { @@ -95,25 +99,6 @@ SERVICE_STOP_TORRENT_SCHEMA = vol.All( ) ) -CONFIG_SCHEMA = cv.removed(DOMAIN, raise_if_present=False) - -PLATFORMS = [Platform.SENSOR, Platform.SWITCH] - -MIGRATION_NAME_TO_KEY = { - # Sensors - "Down Speed": "download", - "Up Speed": "upload", - "Status": "status", - "Active Torrents": "active_torrents", - "Paused Torrents": "paused_torrents", - "Total Torrents": "total_torrents", - "Completed Torrents": "completed_torrents", - "Started Torrents": "started_torrents", - # Switches - "Switch": "on_off", - "Turtle Mode": "turtle_mode", -} - async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool: """Set up the Transmission Component.""" @@ -141,24 +126,81 @@ async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> b except (AuthenticationError, UnknownError) as error: raise ConfigEntryAuthFailed from error - client = TransmissionClient(hass, config_entry, api) - await client.async_setup() - hass.data.setdefault(DOMAIN, {})[config_entry.entry_id] = client + coordinator = TransmissionDataUpdateCoordinator(hass, config_entry, api) + await hass.async_add_executor_job(coordinator.init_torrent_list) + + await coordinator.async_config_entry_first_refresh() + hass.data.setdefault(DOMAIN, {})[config_entry.entry_id] = coordinator await hass.config_entries.async_forward_entry_setups(config_entry, PLATFORMS) - client.register_services() + config_entry.add_update_listener(async_options_updated) + + async def add_torrent(service: ServiceCall) -> None: + """Add new torrent to download.""" + torrent = service.data[ATTR_TORRENT] + if torrent.startswith( + ("http", "ftp:", "magnet:") + ) or hass.config.is_allowed_path(torrent): + await hass.async_add_executor_job(coordinator.api.add_torrent, torrent) + await coordinator.async_request_refresh() + else: + _LOGGER.warning("Could not add torrent: unsupported type or no permission") + + async def start_torrent(service: ServiceCall) -> None: + """Start torrent.""" + torrent_id = service.data[CONF_ID] + await hass.async_add_executor_job(coordinator.api.start_torrent, torrent_id) + await coordinator.async_request_refresh() + + async def stop_torrent(service: ServiceCall) -> None: + """Stop torrent.""" + torrent_id = service.data[CONF_ID] + await hass.async_add_executor_job(coordinator.api.stop_torrent, torrent_id) + await coordinator.async_request_refresh() + + async def remove_torrent(service: ServiceCall) -> None: + """Remove torrent.""" + torrent_id = service.data[CONF_ID] + delete_data = service.data[ATTR_DELETE_DATA] + await hass.async_add_executor_job( + partial(coordinator.api.remove_torrent, torrent_id, delete_data=delete_data) + ) + await coordinator.async_request_refresh() + + hass.services.async_register( + DOMAIN, SERVICE_ADD_TORRENT, add_torrent, schema=SERVICE_ADD_TORRENT_SCHEMA + ) + + hass.services.async_register( + DOMAIN, + SERVICE_REMOVE_TORRENT, + remove_torrent, + schema=SERVICE_REMOVE_TORRENT_SCHEMA, + ) + + hass.services.async_register( + DOMAIN, + SERVICE_START_TORRENT, + start_torrent, + schema=SERVICE_START_TORRENT_SCHEMA, + ) + + hass.services.async_register( + DOMAIN, + SERVICE_STOP_TORRENT, + stop_torrent, + schema=SERVICE_STOP_TORRENT_SCHEMA, + ) + return True async def async_unload_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool: """Unload Transmission Entry from config_entry.""" - client: TransmissionClient = hass.data[DOMAIN].pop(config_entry.entry_id) - if client.unsub_timer: - client.unsub_timer() - - unload_ok = await hass.config_entries.async_unload_platforms( + if unload_ok := await hass.config_entries.async_unload_platforms( config_entry, PLATFORMS - ) + ): + hass.data[DOMAIN].pop(config_entry.entry_id) if not hass.data[DOMAIN]: hass.services.async_remove(DOMAIN, SERVICE_ADD_TORRENT) @@ -202,286 +244,8 @@ async def get_api( raise UnknownError from error -def _get_client(hass: HomeAssistant, data: dict[str, Any]) -> TransmissionClient | None: - """Return client from integration name or entry_id.""" - if ( - (entry_id := data.get(CONF_ENTRY_ID)) - and (entry := hass.config_entries.async_get_entry(entry_id)) - and entry.state == ConfigEntryState.LOADED - ): - return hass.data[DOMAIN][entry_id] - - return None - - -class TransmissionClient: - """Transmission Client Object.""" - - def __init__( - self, - hass: HomeAssistant, - config_entry: ConfigEntry, - api: transmission_rpc.Client, - ) -> None: - """Initialize the Transmission RPC API.""" - self.hass = hass - self.config_entry = config_entry - self.tm_api = api - self._tm_data = TransmissionData(hass, config_entry, api) - self.unsub_timer: Callable[[], None] | None = None - - @property - def api(self) -> TransmissionData: - """Return the TransmissionData object.""" - return self._tm_data - - async def async_setup(self) -> None: - """Set up the Transmission client.""" - await self.hass.async_add_executor_job(self.api.init_torrent_list) - await self.hass.async_add_executor_job(self.api.update) - self.add_options() - self.set_scan_interval(self.config_entry.options[CONF_SCAN_INTERVAL]) - - def register_services(self) -> None: - """Register integration services.""" - - def add_torrent(service: ServiceCall) -> None: - """Add new torrent to download.""" - if not (tm_client := _get_client(self.hass, service.data)): - raise ValueError("Transmission instance is not found") - - torrent = service.data[ATTR_TORRENT] - if torrent.startswith( - ("http", "ftp:", "magnet:") - ) or self.hass.config.is_allowed_path(torrent): - tm_client.tm_api.add_torrent(torrent) - tm_client.api.update() - else: - _LOGGER.warning( - "Could not add torrent: unsupported type or no permission" - ) - - def start_torrent(service: ServiceCall) -> None: - """Start torrent.""" - if not (tm_client := _get_client(self.hass, service.data)): - raise ValueError("Transmission instance is not found") - - torrent_id = service.data[CONF_ID] - tm_client.tm_api.start_torrent(torrent_id) - tm_client.api.update() - - def stop_torrent(service: ServiceCall) -> None: - """Stop torrent.""" - if not (tm_client := _get_client(self.hass, service.data)): - raise ValueError("Transmission instance is not found") - - torrent_id = service.data[CONF_ID] - tm_client.tm_api.stop_torrent(torrent_id) - tm_client.api.update() - - def remove_torrent(service: ServiceCall) -> None: - """Remove torrent.""" - if not (tm_client := _get_client(self.hass, service.data)): - raise ValueError("Transmission instance is not found") - - torrent_id = service.data[CONF_ID] - delete_data = service.data[ATTR_DELETE_DATA] - tm_client.tm_api.remove_torrent(torrent_id, delete_data=delete_data) - tm_client.api.update() - - self.hass.services.async_register( - DOMAIN, SERVICE_ADD_TORRENT, add_torrent, schema=SERVICE_ADD_TORRENT_SCHEMA - ) - - self.hass.services.async_register( - DOMAIN, - SERVICE_REMOVE_TORRENT, - remove_torrent, - schema=SERVICE_REMOVE_TORRENT_SCHEMA, - ) - - self.hass.services.async_register( - DOMAIN, - SERVICE_START_TORRENT, - start_torrent, - schema=SERVICE_START_TORRENT_SCHEMA, - ) - - self.hass.services.async_register( - DOMAIN, - SERVICE_STOP_TORRENT, - stop_torrent, - schema=SERVICE_STOP_TORRENT_SCHEMA, - ) - - self.config_entry.add_update_listener(self.async_options_updated) - - def add_options(self): - """Add options for entry.""" - if not self.config_entry.options: - scan_interval = self.config_entry.data.get( - CONF_SCAN_INTERVAL, DEFAULT_SCAN_INTERVAL - ) - limit = self.config_entry.data.get(CONF_LIMIT, DEFAULT_LIMIT) - order = self.config_entry.data.get(CONF_ORDER, DEFAULT_ORDER) - options = { - CONF_SCAN_INTERVAL: scan_interval, - CONF_LIMIT: limit, - CONF_ORDER: order, - } - - self.hass.config_entries.async_update_entry( - self.config_entry, options=options - ) - - def set_scan_interval(self, scan_interval: float) -> None: - """Update scan interval.""" - - def refresh(event_time: datetime) -> None: - """Get the latest data from Transmission.""" - self.api.update() - - if self.unsub_timer is not None: - self.unsub_timer() - self.unsub_timer = async_track_time_interval( - self.hass, refresh, timedelta(seconds=scan_interval) - ) - - @staticmethod - async def async_options_updated(hass: HomeAssistant, entry: ConfigEntry) -> None: - """Triggered by config entry options updates.""" - tm_client: TransmissionClient = hass.data[DOMAIN][entry.entry_id] - tm_client.set_scan_interval(entry.options[CONF_SCAN_INTERVAL]) - await hass.async_add_executor_job(tm_client.api.update) - - -class TransmissionData: - """Get the latest data and update the states.""" - - def __init__( - self, hass: HomeAssistant, config: ConfigEntry, api: transmission_rpc.Client - ) -> None: - """Initialize the Transmission RPC API.""" - self.hass = hass - self.config = config - self._api: transmission_rpc.Client = api - self.data: SessionStats | None = None - self.available: bool = True - self._session: transmission_rpc.Session | None = None - self._all_torrents: list[transmission_rpc.Torrent] = [] - self._completed_torrents: list[transmission_rpc.Torrent] = [] - self._started_torrents: list[transmission_rpc.Torrent] = [] - self._torrents: list[transmission_rpc.Torrent] = [] - - @property - def host(self) -> str: - """Return the host name.""" - return self.config.data[CONF_HOST] - - @property - def signal_update(self) -> str: - """Update signal per transmission entry.""" - return f"{DATA_UPDATED}-{self.host}" - - @property - def torrents(self) -> list[transmission_rpc.Torrent]: - """Get the list of torrents.""" - return self._torrents - - def update(self) -> None: - """Get the latest data from Transmission instance.""" - try: - self.data = self._api.session_stats() - self._torrents = self._api.get_torrents() - self._session = self._api.get_session() - - self.check_completed_torrent() - self.check_started_torrent() - self.check_removed_torrent() - _LOGGER.debug("Torrent Data for %s Updated", self.host) - - self.available = True - except TransmissionError: - self.available = False - _LOGGER.error("Unable to connect to Transmission client %s", self.host) - dispatcher_send(self.hass, self.signal_update) - - def init_torrent_list(self) -> None: - """Initialize torrent lists.""" - self._torrents = self._api.get_torrents() - self._completed_torrents = [ - torrent for torrent in self._torrents if torrent.status == "seeding" - ] - self._started_torrents = [ - torrent for torrent in self._torrents if torrent.status == "downloading" - ] - - def check_completed_torrent(self) -> None: - """Get completed torrent functionality.""" - old_completed_torrent_names = { - torrent.name for torrent in self._completed_torrents - } - - current_completed_torrents = [ - torrent for torrent in self._torrents if torrent.status == "seeding" - ] - - for torrent in current_completed_torrents: - if torrent.name not in old_completed_torrent_names: - self.hass.bus.fire( - EVENT_DOWNLOADED_TORRENT, {"name": torrent.name, "id": torrent.id} - ) - - self._completed_torrents = current_completed_torrents - - def check_started_torrent(self) -> None: - """Get started torrent functionality.""" - old_started_torrent_names = {torrent.name for torrent in self._started_torrents} - - current_started_torrents = [ - torrent for torrent in self._torrents if torrent.status == "downloading" - ] - - for torrent in current_started_torrents: - if torrent.name not in old_started_torrent_names: - self.hass.bus.fire( - EVENT_STARTED_TORRENT, {"name": torrent.name, "id": torrent.id} - ) - - self._started_torrents = current_started_torrents - - def check_removed_torrent(self) -> None: - """Get removed torrent functionality.""" - current_torrent_names = {torrent.name for torrent in self._torrents} - - for torrent in self._all_torrents: - if torrent.name not in current_torrent_names: - self.hass.bus.fire( - EVENT_REMOVED_TORRENT, {"name": torrent.name, "id": torrent.id} - ) - - self._all_torrents = self._torrents.copy() - - def start_torrents(self) -> None: - """Start all torrents.""" - if not self._torrents: - return - self._api.start_all() - - def stop_torrents(self) -> None: - """Stop all active torrents.""" - if not self._torrents: - return - torrent_ids = [torrent.id for torrent in self._torrents] - self._api.stop_torrent(torrent_ids) - - def set_alt_speed_enabled(self, is_enabled: bool) -> None: - """Set the alternative speed flag.""" - self._api.set_session(alt_speed_enabled=is_enabled) - - def get_alt_speed_enabled(self) -> bool | None: - """Get the alternative speed flag.""" - if self._session is None: - return None - - return self._session.alt_speed_enabled +async def async_options_updated(hass: HomeAssistant, entry: ConfigEntry) -> None: + """Triggered by config entry options updates.""" + coordinator: TransmissionDataUpdateCoordinator = hass.data[DOMAIN][entry.entry_id] + coordinator.update_interval = timedelta(seconds=entry.options[CONF_SCAN_INTERVAL]) + await coordinator.async_request_refresh() diff --git a/homeassistant/components/transmission/const.py b/homeassistant/components/transmission/const.py index da861d2698c..cb31d5a5aac 100644 --- a/homeassistant/components/transmission/const.py +++ b/homeassistant/components/transmission/const.py @@ -39,8 +39,6 @@ SERVICE_REMOVE_TORRENT = "remove_torrent" SERVICE_START_TORRENT = "start_torrent" SERVICE_STOP_TORRENT = "stop_torrent" -DATA_UPDATED = "transmission_data_updated" - EVENT_STARTED_TORRENT = "transmission_started_torrent" EVENT_REMOVED_TORRENT = "transmission_removed_torrent" EVENT_DOWNLOADED_TORRENT = "transmission_downloaded_torrent" diff --git a/homeassistant/components/transmission/coordinator.py b/homeassistant/components/transmission/coordinator.py new file mode 100644 index 00000000000..5fce7cae53d --- /dev/null +++ b/homeassistant/components/transmission/coordinator.py @@ -0,0 +1,166 @@ +"""Coordinator for transmssion integration.""" +from __future__ import annotations + +from datetime import timedelta +import logging + +import transmission_rpc +from transmission_rpc.session import SessionStats + +from homeassistant.config_entries import ConfigEntry +from homeassistant.const import CONF_HOST, CONF_SCAN_INTERVAL +from homeassistant.core import HomeAssistant +from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed + +from .const import ( + CONF_LIMIT, + CONF_ORDER, + DEFAULT_LIMIT, + DEFAULT_ORDER, + DEFAULT_SCAN_INTERVAL, + DOMAIN, + EVENT_DOWNLOADED_TORRENT, + EVENT_REMOVED_TORRENT, + EVENT_STARTED_TORRENT, +) + +_LOGGER = logging.getLogger(__name__) + + +class TransmissionDataUpdateCoordinator(DataUpdateCoordinator[SessionStats]): + """Transmission dataupdate coordinator class.""" + + config_entry: ConfigEntry + + def __init__( + self, hass: HomeAssistant, entry: ConfigEntry, api: transmission_rpc.Client + ) -> None: + """Initialize the Transmission RPC API.""" + self.config_entry = entry + self.api = api + self.host = entry.data[CONF_HOST] + self._session: transmission_rpc.Session | None = None + self._all_torrents: list[transmission_rpc.Torrent] = [] + self._completed_torrents: list[transmission_rpc.Torrent] = [] + self._started_torrents: list[transmission_rpc.Torrent] = [] + self.torrents: list[transmission_rpc.Torrent] = [] + super().__init__( + hass, + name=f"{DOMAIN} - {self.host}", + logger=_LOGGER, + update_interval=timedelta(seconds=self.scan_interval), + ) + + @property + def scan_interval(self) -> float: + """Return scan interval.""" + return self.config_entry.options.get(CONF_SCAN_INTERVAL, DEFAULT_SCAN_INTERVAL) + + @property + def limit(self) -> int: + """Return limit.""" + return self.config_entry.data.get(CONF_LIMIT, DEFAULT_LIMIT) + + @property + def order(self) -> str: + """Return order.""" + return self.config_entry.data.get(CONF_ORDER, DEFAULT_ORDER) + + async def _async_update_data(self) -> SessionStats: + """Update transmission data.""" + return await self.hass.async_add_executor_job(self.update) + + def update(self) -> SessionStats: + """Get the latest data from Transmission instance.""" + try: + data = self.api.session_stats() + self.torrents = self.api.get_torrents() + self._session = self.api.get_session() + + self.check_completed_torrent() + self.check_started_torrent() + self.check_removed_torrent() + except transmission_rpc.TransmissionError as err: + raise UpdateFailed("Unable to connect to Transmission client") from err + + return data + + def init_torrent_list(self) -> None: + """Initialize torrent lists.""" + self.torrents = self.api.get_torrents() + self._completed_torrents = [ + torrent for torrent in self.torrents if torrent.status == "seeding" + ] + self._started_torrents = [ + torrent for torrent in self.torrents if torrent.status == "downloading" + ] + + def check_completed_torrent(self) -> None: + """Get completed torrent functionality.""" + old_completed_torrent_names = { + torrent.name for torrent in self._completed_torrents + } + + current_completed_torrents = [ + torrent for torrent in self.torrents if torrent.status == "seeding" + ] + + for torrent in current_completed_torrents: + if torrent.name not in old_completed_torrent_names: + self.hass.bus.fire( + EVENT_DOWNLOADED_TORRENT, {"name": torrent.name, "id": torrent.id} + ) + + self._completed_torrents = current_completed_torrents + + def check_started_torrent(self) -> None: + """Get started torrent functionality.""" + old_started_torrent_names = {torrent.name for torrent in self._started_torrents} + + current_started_torrents = [ + torrent for torrent in self.torrents if torrent.status == "downloading" + ] + + for torrent in current_started_torrents: + if torrent.name not in old_started_torrent_names: + self.hass.bus.fire( + EVENT_STARTED_TORRENT, {"name": torrent.name, "id": torrent.id} + ) + + self._started_torrents = current_started_torrents + + def check_removed_torrent(self) -> None: + """Get removed torrent functionality.""" + current_torrent_names = {torrent.name for torrent in self.torrents} + + for torrent in self._all_torrents: + if torrent.name not in current_torrent_names: + self.hass.bus.fire( + EVENT_REMOVED_TORRENT, {"name": torrent.name, "id": torrent.id} + ) + + self._all_torrents = self.torrents.copy() + + def start_torrents(self) -> None: + """Start all torrents.""" + if not self.torrents: + return + self.api.start_all() + + def stop_torrents(self) -> None: + """Stop all active torrents.""" + if not self.torrents: + return + torrent_ids = [torrent.id for torrent in self.torrents] + self.api.stop_torrent(torrent_ids) + + def set_alt_speed_enabled(self, is_enabled: bool) -> None: + """Set the alternative speed flag.""" + self.api.set_session(alt_speed_enabled=is_enabled) + + def get_alt_speed_enabled(self) -> bool | None: + """Get the alternative speed flag.""" + if self._session is None: + return None + + return self._session.alt_speed_enabled diff --git a/homeassistant/components/transmission/sensor.py b/homeassistant/components/transmission/sensor.py index 93bea8a25c9..0b949e73f47 100644 --- a/homeassistant/components/transmission/sensor.py +++ b/homeassistant/components/transmission/sensor.py @@ -4,21 +4,18 @@ from __future__ import annotations from contextlib import suppress from typing import Any +from transmission_rpc.session import SessionStats from transmission_rpc.torrent import Torrent from homeassistant.components.sensor import SensorDeviceClass, SensorEntity from homeassistant.config_entries import ConfigEntry from homeassistant.const import CONF_NAME, STATE_IDLE, UnitOfDataRate -from homeassistant.core import HomeAssistant, callback +from homeassistant.core import HomeAssistant from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo -from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.entity_platform import AddEntitiesCallback -from homeassistant.helpers.typing import StateType +from homeassistant.helpers.update_coordinator import CoordinatorEntity -from . import TransmissionClient from .const import ( - CONF_LIMIT, - CONF_ORDER, DOMAIN, STATE_ATTR_TORRENT_INFO, STATE_DOWNLOADING, @@ -26,6 +23,7 @@ from .const import ( STATE_UP_DOWN, SUPPORTED_ORDER_MODES, ) +from .coordinator import TransmissionDataUpdateCoordinator async def async_setup_entry( @@ -35,54 +33,56 @@ async def async_setup_entry( ) -> None: """Set up the Transmission sensors.""" - tm_client: TransmissionClient = hass.data[DOMAIN][config_entry.entry_id] + coordinator: TransmissionDataUpdateCoordinator = hass.data[DOMAIN][ + config_entry.entry_id + ] name: str = config_entry.data[CONF_NAME] dev = [ TransmissionSpeedSensor( - tm_client, + coordinator, name, "download_speed", "download", ), TransmissionSpeedSensor( - tm_client, + coordinator, name, "upload_speed", "upload", ), TransmissionStatusSensor( - tm_client, + coordinator, name, "transmission_status", "status", ), TransmissionTorrentsSensor( - tm_client, + coordinator, name, "active_torrents", "active_torrents", ), TransmissionTorrentsSensor( - tm_client, + coordinator, name, "paused_torrents", "paused_torrents", ), TransmissionTorrentsSensor( - tm_client, + coordinator, name, "total_torrents", "total_torrents", ), TransmissionTorrentsSensor( - tm_client, + coordinator, name, "completed_torrents", "completed_torrents", ), TransmissionTorrentsSensor( - tm_client, + coordinator, name, "started_torrents", "started_torrents", @@ -92,7 +92,7 @@ async def async_setup_entry( async_add_entities(dev, True) -class TransmissionSensor(SensorEntity): +class TransmissionSensor(CoordinatorEntity[SessionStats], SensorEntity): """A base class for all Transmission sensors.""" _attr_has_entity_name = True @@ -100,48 +100,23 @@ class TransmissionSensor(SensorEntity): def __init__( self, - tm_client: TransmissionClient, + coordinator: TransmissionDataUpdateCoordinator, client_name: str, sensor_translation_key: str, key: str, ) -> None: """Initialize the sensor.""" - self._tm_client = tm_client + super().__init__(coordinator) self._attr_translation_key = sensor_translation_key self._key = key - self._state: StateType = None - self._attr_unique_id = f"{tm_client.config_entry.entry_id}-{key}" + self._attr_unique_id = f"{coordinator.config_entry.entry_id}-{key}" self._attr_device_info = DeviceInfo( entry_type=DeviceEntryType.SERVICE, - identifiers={(DOMAIN, tm_client.config_entry.entry_id)}, + identifiers={(DOMAIN, coordinator.config_entry.entry_id)}, manufacturer="Transmission", name=client_name, ) - @property - def native_value(self) -> StateType: - """Return the state of the sensor.""" - return self._state - - @property - def available(self) -> bool: - """Could the device be accessed during the last update call.""" - return self._tm_client.api.available - - async def async_added_to_hass(self) -> None: - """Handle entity which will be added.""" - - @callback - def update(): - """Update the state.""" - self.async_schedule_update_ha_state(True) - - self.async_on_remove( - async_dispatcher_connect( - self.hass, self._tm_client.api.signal_update, update - ) - ) - class TransmissionSpeedSensor(TransmissionSensor): """Representation of a Transmission speed sensor.""" @@ -151,15 +126,15 @@ class TransmissionSpeedSensor(TransmissionSensor): _attr_suggested_display_precision = 2 _attr_suggested_unit_of_measurement = UnitOfDataRate.MEGABYTES_PER_SECOND - def update(self) -> None: - """Get the latest data from Transmission and updates the state.""" - if data := self._tm_client.api.data: - b_spd = ( - float(data.download_speed) - if self._key == "download" - else float(data.upload_speed) - ) - self._state = b_spd + @property + def native_value(self) -> float: + """Return the speed of the sensor.""" + data = self.coordinator.data + return ( + float(data.download_speed) + if self._key == "download" + else float(data.upload_speed) + ) class TransmissionStatusSensor(TransmissionSensor): @@ -168,21 +143,18 @@ class TransmissionStatusSensor(TransmissionSensor): _attr_device_class = SensorDeviceClass.ENUM _attr_options = [STATE_IDLE, STATE_UP_DOWN, STATE_SEEDING, STATE_DOWNLOADING] - def update(self) -> None: - """Get the latest data from Transmission and updates the state.""" - if data := self._tm_client.api.data: - upload = data.upload_speed - download = data.download_speed - if upload > 0 and download > 0: - self._state = STATE_UP_DOWN - elif upload > 0 and download == 0: - self._state = STATE_SEEDING - elif upload == 0 and download > 0: - self._state = STATE_DOWNLOADING - else: - self._state = STATE_IDLE - else: - self._state = None + @property + def native_value(self) -> str: + """Return the value of the status sensor.""" + upload = self.coordinator.data.upload_speed + download = self.coordinator.data.download_speed + if upload > 0 and download > 0: + return STATE_UP_DOWN + if upload > 0 and download == 0: + return STATE_SEEDING + if upload == 0 and download > 0: + return STATE_DOWNLOADING + return STATE_IDLE class TransmissionTorrentsSensor(TransmissionSensor): @@ -208,21 +180,22 @@ class TransmissionTorrentsSensor(TransmissionSensor): def extra_state_attributes(self) -> dict[str, Any]: """Return the state attributes, if any.""" info = _torrents_info( - torrents=self._tm_client.api.torrents, - order=self._tm_client.config_entry.options[CONF_ORDER], - limit=self._tm_client.config_entry.options[CONF_LIMIT], + torrents=self.coordinator.torrents, + order=self.coordinator.order, + limit=self.coordinator.limit, statuses=self.MODES[self._key], ) return { STATE_ATTR_TORRENT_INFO: info, } - def update(self) -> None: - """Get the latest data from Transmission and updates the state.""" + @property + def native_value(self) -> int: + """Return the count of the sensor.""" torrents = _filter_torrents( - self._tm_client.api.torrents, statuses=self.MODES[self._key] + self.coordinator.torrents, statuses=self.MODES[self._key] ) - self._state = len(torrents) + return len(torrents) def _filter_torrents( diff --git a/homeassistant/components/transmission/switch.py b/homeassistant/components/transmission/switch.py index fad099fc5b9..253ceb558b9 100644 --- a/homeassistant/components/transmission/switch.py +++ b/homeassistant/components/transmission/switch.py @@ -3,16 +3,18 @@ from collections.abc import Callable import logging from typing import Any +from transmission_rpc.session import SessionStats + from homeassistant.components.switch import SwitchEntity from homeassistant.config_entries import ConfigEntry -from homeassistant.const import CONF_NAME, STATE_OFF, STATE_ON -from homeassistant.core import HomeAssistant, callback +from homeassistant.const import CONF_NAME +from homeassistant.core import HomeAssistant from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo -from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.entity_platform import AddEntitiesCallback +from homeassistant.helpers.update_coordinator import CoordinatorEntity -from . import TransmissionClient from .const import DOMAIN, SWITCH_TYPES +from .coordinator import TransmissionDataUpdateCoordinator _LOGGING = logging.getLogger(__name__) @@ -24,17 +26,19 @@ async def async_setup_entry( ) -> None: """Set up the Transmission switch.""" - tm_client: TransmissionClient = hass.data[DOMAIN][config_entry.entry_id] + coordinator: TransmissionDataUpdateCoordinator = hass.data[DOMAIN][ + config_entry.entry_id + ] name: str = config_entry.data[CONF_NAME] dev = [] for switch_type, switch_name in SWITCH_TYPES.items(): - dev.append(TransmissionSwitch(switch_type, switch_name, tm_client, name)) + dev.append(TransmissionSwitch(switch_type, switch_name, coordinator, name)) async_add_entities(dev, True) -class TransmissionSwitch(SwitchEntity): +class TransmissionSwitch(CoordinatorEntity[SessionStats], SwitchEntity): """Representation of a Transmission switch.""" _attr_has_entity_name = True @@ -44,20 +48,18 @@ class TransmissionSwitch(SwitchEntity): self, switch_type: str, switch_name: str, - tm_client: TransmissionClient, + coordinator: TransmissionDataUpdateCoordinator, client_name: str, ) -> None: """Initialize the Transmission switch.""" + super().__init__(coordinator) self._attr_name = switch_name self.type = switch_type - self._tm_client = tm_client - self._state = STATE_OFF - self._data = None self.unsub_update: Callable[[], None] | None = None - self._attr_unique_id = f"{tm_client.config_entry.entry_id}-{switch_type}" + self._attr_unique_id = f"{coordinator.config_entry.entry_id}-{switch_type}" self._attr_device_info = DeviceInfo( entry_type=DeviceEntryType.SERVICE, - identifiers={(DOMAIN, tm_client.config_entry.entry_id)}, + identifiers={(DOMAIN, coordinator.config_entry.entry_id)}, manufacturer="Transmission", name=client_name, ) @@ -65,63 +67,34 @@ class TransmissionSwitch(SwitchEntity): @property def is_on(self) -> bool: """Return true if device is on.""" - return self._state == STATE_ON + active = None + if self.type == "on_off": + active = self.coordinator.data.active_torrent_count > 0 + elif self.type == "turtle_mode": + active = self.coordinator.get_alt_speed_enabled() - @property - def available(self) -> bool: - """Could the device be accessed during the last update call.""" - return self._tm_client.api.available + return bool(active) - def turn_on(self, **kwargs: Any) -> None: + async def async_turn_on(self, **kwargs: Any) -> None: """Turn the device on.""" if self.type == "on_off": _LOGGING.debug("Starting all torrents") - self._tm_client.api.start_torrents() + await self.hass.async_add_executor_job(self.coordinator.start_torrents) elif self.type == "turtle_mode": _LOGGING.debug("Turning Turtle Mode of Transmission on") - self._tm_client.api.set_alt_speed_enabled(True) - self._tm_client.api.update() + await self.hass.async_add_executor_job( + self.coordinator.set_alt_speed_enabled, True + ) + await self.coordinator.async_request_refresh() - def turn_off(self, **kwargs: Any) -> None: + async def async_turn_off(self, **kwargs: Any) -> None: """Turn the device off.""" if self.type == "on_off": _LOGGING.debug("Stopping all torrents") - self._tm_client.api.stop_torrents() + await self.hass.async_add_executor_job(self.coordinator.stop_torrents) if self.type == "turtle_mode": _LOGGING.debug("Turning Turtle Mode of Transmission off") - self._tm_client.api.set_alt_speed_enabled(False) - self._tm_client.api.update() - - async def async_added_to_hass(self) -> None: - """Handle entity which will be added.""" - self.unsub_update = async_dispatcher_connect( - self.hass, - self._tm_client.api.signal_update, - self._schedule_immediate_update, - ) - - @callback - def _schedule_immediate_update(self) -> None: - self.async_schedule_update_ha_state(True) - - async def will_remove_from_hass(self) -> None: - """Unsubscribe from update dispatcher.""" - if self.unsub_update: - self.unsub_update() - self.unsub_update = None - - def update(self) -> None: - """Get the latest data from Transmission and updates the state.""" - active = None - if self.type == "on_off": - self._data = self._tm_client.api.data - if self._data: - active = self._data.active_torrent_count > 0 - - elif self.type == "turtle_mode": - active = self._tm_client.api.get_alt_speed_enabled() - - if active is None: - return - - self._state = STATE_ON if active else STATE_OFF + await self.hass.async_add_executor_job( + self.coordinator.set_alt_speed_enabled, False + ) + await self.coordinator.async_request_refresh()