Use DataUpdate coordinator for Transmission (#99209)

* Switch integration to DataUpdate Coordinator

* add coordinator to .coveragerc

* Migrate TransmissionData into DUC

* update coveragerc

* Applu suggestions

* remove CONFIG_SCHEMA
This commit is contained in:
Rami Mosleh 2023-10-12 21:58:22 +03:00 committed by GitHub
parent cc3d1a11bd
commit 536ad57bf4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 340 additions and 465 deletions

View File

@ -1396,6 +1396,7 @@ omit =
homeassistant/components/trafikverket_weatherstation/coordinator.py homeassistant/components/trafikverket_weatherstation/coordinator.py
homeassistant/components/trafikverket_weatherstation/sensor.py homeassistant/components/trafikverket_weatherstation/sensor.py
homeassistant/components/transmission/__init__.py homeassistant/components/transmission/__init__.py
homeassistant/components/transmission/coordinator.py
homeassistant/components/transmission/sensor.py homeassistant/components/transmission/sensor.py
homeassistant/components/transmission/switch.py homeassistant/components/transmission/switch.py
homeassistant/components/travisci/sensor.py homeassistant/components/travisci/sensor.py

View File

@ -1,8 +1,7 @@
"""Support for the Transmission BitTorrent client API.""" """Support for the Transmission BitTorrent client API."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable from datetime import timedelta
from datetime import datetime, timedelta
from functools import partial from functools import partial
import logging import logging
import re import re
@ -14,10 +13,9 @@ from transmission_rpc.error import (
TransmissionConnectError, TransmissionConnectError,
TransmissionError, TransmissionError,
) )
from transmission_rpc.session import SessionStats
import voluptuous as vol import voluptuous as vol
from homeassistant.config_entries import ConfigEntry, ConfigEntryState from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ( from homeassistant.const import (
CONF_HOST, CONF_HOST,
CONF_ID, CONF_ID,
@ -35,33 +33,39 @@ from homeassistant.helpers import (
entity_registry as er, entity_registry as er,
selector, selector,
) )
from homeassistant.helpers.dispatcher import dispatcher_send
from homeassistant.helpers.event import async_track_time_interval
from .const import ( from .const import (
ATTR_DELETE_DATA, ATTR_DELETE_DATA,
ATTR_TORRENT, ATTR_TORRENT,
CONF_ENTRY_ID, CONF_ENTRY_ID,
CONF_LIMIT,
CONF_ORDER,
DATA_UPDATED,
DEFAULT_DELETE_DATA, DEFAULT_DELETE_DATA,
DEFAULT_LIMIT,
DEFAULT_ORDER,
DEFAULT_SCAN_INTERVAL,
DOMAIN, DOMAIN,
EVENT_DOWNLOADED_TORRENT,
EVENT_REMOVED_TORRENT,
EVENT_STARTED_TORRENT,
SERVICE_ADD_TORRENT, SERVICE_ADD_TORRENT,
SERVICE_REMOVE_TORRENT, SERVICE_REMOVE_TORRENT,
SERVICE_START_TORRENT, SERVICE_START_TORRENT,
SERVICE_STOP_TORRENT, SERVICE_STOP_TORRENT,
) )
from .coordinator import TransmissionDataUpdateCoordinator
from .errors import AuthenticationError, CannotConnect, UnknownError from .errors import AuthenticationError, CannotConnect, UnknownError
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
PLATFORMS = [Platform.SENSOR, Platform.SWITCH]
MIGRATION_NAME_TO_KEY = {
# Sensors
"Down Speed": "download",
"Up Speed": "upload",
"Status": "status",
"Active Torrents": "active_torrents",
"Paused Torrents": "paused_torrents",
"Total Torrents": "total_torrents",
"Completed Torrents": "completed_torrents",
"Started Torrents": "started_torrents",
# Switches
"Switch": "on_off",
"Turtle Mode": "turtle_mode",
}
SERVICE_BASE_SCHEMA = vol.Schema( SERVICE_BASE_SCHEMA = vol.Schema(
{ {
@ -95,25 +99,6 @@ SERVICE_STOP_TORRENT_SCHEMA = vol.All(
) )
) )
CONFIG_SCHEMA = cv.removed(DOMAIN, raise_if_present=False)
PLATFORMS = [Platform.SENSOR, Platform.SWITCH]
MIGRATION_NAME_TO_KEY = {
# Sensors
"Down Speed": "download",
"Up Speed": "upload",
"Status": "status",
"Active Torrents": "active_torrents",
"Paused Torrents": "paused_torrents",
"Total Torrents": "total_torrents",
"Completed Torrents": "completed_torrents",
"Started Torrents": "started_torrents",
# Switches
"Switch": "on_off",
"Turtle Mode": "turtle_mode",
}
async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool: async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool:
"""Set up the Transmission Component.""" """Set up the Transmission Component."""
@ -141,24 +126,81 @@ async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> b
except (AuthenticationError, UnknownError) as error: except (AuthenticationError, UnknownError) as error:
raise ConfigEntryAuthFailed from error raise ConfigEntryAuthFailed from error
client = TransmissionClient(hass, config_entry, api) coordinator = TransmissionDataUpdateCoordinator(hass, config_entry, api)
await client.async_setup() await hass.async_add_executor_job(coordinator.init_torrent_list)
hass.data.setdefault(DOMAIN, {})[config_entry.entry_id] = client
await coordinator.async_config_entry_first_refresh()
hass.data.setdefault(DOMAIN, {})[config_entry.entry_id] = coordinator
await hass.config_entries.async_forward_entry_setups(config_entry, PLATFORMS) await hass.config_entries.async_forward_entry_setups(config_entry, PLATFORMS)
client.register_services() config_entry.add_update_listener(async_options_updated)
async def add_torrent(service: ServiceCall) -> None:
"""Add new torrent to download."""
torrent = service.data[ATTR_TORRENT]
if torrent.startswith(
("http", "ftp:", "magnet:")
) or hass.config.is_allowed_path(torrent):
await hass.async_add_executor_job(coordinator.api.add_torrent, torrent)
await coordinator.async_request_refresh()
else:
_LOGGER.warning("Could not add torrent: unsupported type or no permission")
async def start_torrent(service: ServiceCall) -> None:
"""Start torrent."""
torrent_id = service.data[CONF_ID]
await hass.async_add_executor_job(coordinator.api.start_torrent, torrent_id)
await coordinator.async_request_refresh()
async def stop_torrent(service: ServiceCall) -> None:
"""Stop torrent."""
torrent_id = service.data[CONF_ID]
await hass.async_add_executor_job(coordinator.api.stop_torrent, torrent_id)
await coordinator.async_request_refresh()
async def remove_torrent(service: ServiceCall) -> None:
"""Remove torrent."""
torrent_id = service.data[CONF_ID]
delete_data = service.data[ATTR_DELETE_DATA]
await hass.async_add_executor_job(
partial(coordinator.api.remove_torrent, torrent_id, delete_data=delete_data)
)
await coordinator.async_request_refresh()
hass.services.async_register(
DOMAIN, SERVICE_ADD_TORRENT, add_torrent, schema=SERVICE_ADD_TORRENT_SCHEMA
)
hass.services.async_register(
DOMAIN,
SERVICE_REMOVE_TORRENT,
remove_torrent,
schema=SERVICE_REMOVE_TORRENT_SCHEMA,
)
hass.services.async_register(
DOMAIN,
SERVICE_START_TORRENT,
start_torrent,
schema=SERVICE_START_TORRENT_SCHEMA,
)
hass.services.async_register(
DOMAIN,
SERVICE_STOP_TORRENT,
stop_torrent,
schema=SERVICE_STOP_TORRENT_SCHEMA,
)
return True return True
async def async_unload_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool: async def async_unload_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool:
"""Unload Transmission Entry from config_entry.""" """Unload Transmission Entry from config_entry."""
client: TransmissionClient = hass.data[DOMAIN].pop(config_entry.entry_id) if unload_ok := await hass.config_entries.async_unload_platforms(
if client.unsub_timer:
client.unsub_timer()
unload_ok = await hass.config_entries.async_unload_platforms(
config_entry, PLATFORMS config_entry, PLATFORMS
) ):
hass.data[DOMAIN].pop(config_entry.entry_id)
if not hass.data[DOMAIN]: if not hass.data[DOMAIN]:
hass.services.async_remove(DOMAIN, SERVICE_ADD_TORRENT) hass.services.async_remove(DOMAIN, SERVICE_ADD_TORRENT)
@ -202,286 +244,8 @@ async def get_api(
raise UnknownError from error raise UnknownError from error
def _get_client(hass: HomeAssistant, data: dict[str, Any]) -> TransmissionClient | None: async def async_options_updated(hass: HomeAssistant, entry: ConfigEntry) -> None:
"""Return client from integration name or entry_id."""
if (
(entry_id := data.get(CONF_ENTRY_ID))
and (entry := hass.config_entries.async_get_entry(entry_id))
and entry.state == ConfigEntryState.LOADED
):
return hass.data[DOMAIN][entry_id]
return None
class TransmissionClient:
"""Transmission Client Object."""
def __init__(
self,
hass: HomeAssistant,
config_entry: ConfigEntry,
api: transmission_rpc.Client,
) -> None:
"""Initialize the Transmission RPC API."""
self.hass = hass
self.config_entry = config_entry
self.tm_api = api
self._tm_data = TransmissionData(hass, config_entry, api)
self.unsub_timer: Callable[[], None] | None = None
@property
def api(self) -> TransmissionData:
"""Return the TransmissionData object."""
return self._tm_data
async def async_setup(self) -> None:
"""Set up the Transmission client."""
await self.hass.async_add_executor_job(self.api.init_torrent_list)
await self.hass.async_add_executor_job(self.api.update)
self.add_options()
self.set_scan_interval(self.config_entry.options[CONF_SCAN_INTERVAL])
def register_services(self) -> None:
"""Register integration services."""
def add_torrent(service: ServiceCall) -> None:
"""Add new torrent to download."""
if not (tm_client := _get_client(self.hass, service.data)):
raise ValueError("Transmission instance is not found")
torrent = service.data[ATTR_TORRENT]
if torrent.startswith(
("http", "ftp:", "magnet:")
) or self.hass.config.is_allowed_path(torrent):
tm_client.tm_api.add_torrent(torrent)
tm_client.api.update()
else:
_LOGGER.warning(
"Could not add torrent: unsupported type or no permission"
)
def start_torrent(service: ServiceCall) -> None:
"""Start torrent."""
if not (tm_client := _get_client(self.hass, service.data)):
raise ValueError("Transmission instance is not found")
torrent_id = service.data[CONF_ID]
tm_client.tm_api.start_torrent(torrent_id)
tm_client.api.update()
def stop_torrent(service: ServiceCall) -> None:
"""Stop torrent."""
if not (tm_client := _get_client(self.hass, service.data)):
raise ValueError("Transmission instance is not found")
torrent_id = service.data[CONF_ID]
tm_client.tm_api.stop_torrent(torrent_id)
tm_client.api.update()
def remove_torrent(service: ServiceCall) -> None:
"""Remove torrent."""
if not (tm_client := _get_client(self.hass, service.data)):
raise ValueError("Transmission instance is not found")
torrent_id = service.data[CONF_ID]
delete_data = service.data[ATTR_DELETE_DATA]
tm_client.tm_api.remove_torrent(torrent_id, delete_data=delete_data)
tm_client.api.update()
self.hass.services.async_register(
DOMAIN, SERVICE_ADD_TORRENT, add_torrent, schema=SERVICE_ADD_TORRENT_SCHEMA
)
self.hass.services.async_register(
DOMAIN,
SERVICE_REMOVE_TORRENT,
remove_torrent,
schema=SERVICE_REMOVE_TORRENT_SCHEMA,
)
self.hass.services.async_register(
DOMAIN,
SERVICE_START_TORRENT,
start_torrent,
schema=SERVICE_START_TORRENT_SCHEMA,
)
self.hass.services.async_register(
DOMAIN,
SERVICE_STOP_TORRENT,
stop_torrent,
schema=SERVICE_STOP_TORRENT_SCHEMA,
)
self.config_entry.add_update_listener(self.async_options_updated)
def add_options(self):
"""Add options for entry."""
if not self.config_entry.options:
scan_interval = self.config_entry.data.get(
CONF_SCAN_INTERVAL, DEFAULT_SCAN_INTERVAL
)
limit = self.config_entry.data.get(CONF_LIMIT, DEFAULT_LIMIT)
order = self.config_entry.data.get(CONF_ORDER, DEFAULT_ORDER)
options = {
CONF_SCAN_INTERVAL: scan_interval,
CONF_LIMIT: limit,
CONF_ORDER: order,
}
self.hass.config_entries.async_update_entry(
self.config_entry, options=options
)
def set_scan_interval(self, scan_interval: float) -> None:
"""Update scan interval."""
def refresh(event_time: datetime) -> None:
"""Get the latest data from Transmission."""
self.api.update()
if self.unsub_timer is not None:
self.unsub_timer()
self.unsub_timer = async_track_time_interval(
self.hass, refresh, timedelta(seconds=scan_interval)
)
@staticmethod
async def async_options_updated(hass: HomeAssistant, entry: ConfigEntry) -> None:
"""Triggered by config entry options updates.""" """Triggered by config entry options updates."""
tm_client: TransmissionClient = hass.data[DOMAIN][entry.entry_id] coordinator: TransmissionDataUpdateCoordinator = hass.data[DOMAIN][entry.entry_id]
tm_client.set_scan_interval(entry.options[CONF_SCAN_INTERVAL]) coordinator.update_interval = timedelta(seconds=entry.options[CONF_SCAN_INTERVAL])
await hass.async_add_executor_job(tm_client.api.update) await coordinator.async_request_refresh()
class TransmissionData:
"""Get the latest data and update the states."""
def __init__(
self, hass: HomeAssistant, config: ConfigEntry, api: transmission_rpc.Client
) -> None:
"""Initialize the Transmission RPC API."""
self.hass = hass
self.config = config
self._api: transmission_rpc.Client = api
self.data: SessionStats | None = None
self.available: bool = True
self._session: transmission_rpc.Session | None = None
self._all_torrents: list[transmission_rpc.Torrent] = []
self._completed_torrents: list[transmission_rpc.Torrent] = []
self._started_torrents: list[transmission_rpc.Torrent] = []
self._torrents: list[transmission_rpc.Torrent] = []
@property
def host(self) -> str:
"""Return the host name."""
return self.config.data[CONF_HOST]
@property
def signal_update(self) -> str:
"""Update signal per transmission entry."""
return f"{DATA_UPDATED}-{self.host}"
@property
def torrents(self) -> list[transmission_rpc.Torrent]:
"""Get the list of torrents."""
return self._torrents
def update(self) -> None:
"""Get the latest data from Transmission instance."""
try:
self.data = self._api.session_stats()
self._torrents = self._api.get_torrents()
self._session = self._api.get_session()
self.check_completed_torrent()
self.check_started_torrent()
self.check_removed_torrent()
_LOGGER.debug("Torrent Data for %s Updated", self.host)
self.available = True
except TransmissionError:
self.available = False
_LOGGER.error("Unable to connect to Transmission client %s", self.host)
dispatcher_send(self.hass, self.signal_update)
def init_torrent_list(self) -> None:
"""Initialize torrent lists."""
self._torrents = self._api.get_torrents()
self._completed_torrents = [
torrent for torrent in self._torrents if torrent.status == "seeding"
]
self._started_torrents = [
torrent for torrent in self._torrents if torrent.status == "downloading"
]
def check_completed_torrent(self) -> None:
"""Get completed torrent functionality."""
old_completed_torrent_names = {
torrent.name for torrent in self._completed_torrents
}
current_completed_torrents = [
torrent for torrent in self._torrents if torrent.status == "seeding"
]
for torrent in current_completed_torrents:
if torrent.name not in old_completed_torrent_names:
self.hass.bus.fire(
EVENT_DOWNLOADED_TORRENT, {"name": torrent.name, "id": torrent.id}
)
self._completed_torrents = current_completed_torrents
def check_started_torrent(self) -> None:
"""Get started torrent functionality."""
old_started_torrent_names = {torrent.name for torrent in self._started_torrents}
current_started_torrents = [
torrent for torrent in self._torrents if torrent.status == "downloading"
]
for torrent in current_started_torrents:
if torrent.name not in old_started_torrent_names:
self.hass.bus.fire(
EVENT_STARTED_TORRENT, {"name": torrent.name, "id": torrent.id}
)
self._started_torrents = current_started_torrents
def check_removed_torrent(self) -> None:
"""Get removed torrent functionality."""
current_torrent_names = {torrent.name for torrent in self._torrents}
for torrent in self._all_torrents:
if torrent.name not in current_torrent_names:
self.hass.bus.fire(
EVENT_REMOVED_TORRENT, {"name": torrent.name, "id": torrent.id}
)
self._all_torrents = self._torrents.copy()
def start_torrents(self) -> None:
"""Start all torrents."""
if not self._torrents:
return
self._api.start_all()
def stop_torrents(self) -> None:
"""Stop all active torrents."""
if not self._torrents:
return
torrent_ids = [torrent.id for torrent in self._torrents]
self._api.stop_torrent(torrent_ids)
def set_alt_speed_enabled(self, is_enabled: bool) -> None:
"""Set the alternative speed flag."""
self._api.set_session(alt_speed_enabled=is_enabled)
def get_alt_speed_enabled(self) -> bool | None:
"""Get the alternative speed flag."""
if self._session is None:
return None
return self._session.alt_speed_enabled

View File

@ -39,8 +39,6 @@ SERVICE_REMOVE_TORRENT = "remove_torrent"
SERVICE_START_TORRENT = "start_torrent" SERVICE_START_TORRENT = "start_torrent"
SERVICE_STOP_TORRENT = "stop_torrent" SERVICE_STOP_TORRENT = "stop_torrent"
DATA_UPDATED = "transmission_data_updated"
EVENT_STARTED_TORRENT = "transmission_started_torrent" EVENT_STARTED_TORRENT = "transmission_started_torrent"
EVENT_REMOVED_TORRENT = "transmission_removed_torrent" EVENT_REMOVED_TORRENT = "transmission_removed_torrent"
EVENT_DOWNLOADED_TORRENT = "transmission_downloaded_torrent" EVENT_DOWNLOADED_TORRENT = "transmission_downloaded_torrent"

View File

@ -0,0 +1,166 @@
"""Coordinator for transmssion integration."""
from __future__ import annotations
from datetime import timedelta
import logging
import transmission_rpc
from transmission_rpc.session import SessionStats
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_HOST, CONF_SCAN_INTERVAL
from homeassistant.core import HomeAssistant
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed
from .const import (
CONF_LIMIT,
CONF_ORDER,
DEFAULT_LIMIT,
DEFAULT_ORDER,
DEFAULT_SCAN_INTERVAL,
DOMAIN,
EVENT_DOWNLOADED_TORRENT,
EVENT_REMOVED_TORRENT,
EVENT_STARTED_TORRENT,
)
_LOGGER = logging.getLogger(__name__)
class TransmissionDataUpdateCoordinator(DataUpdateCoordinator[SessionStats]):
"""Transmission dataupdate coordinator class."""
config_entry: ConfigEntry
def __init__(
self, hass: HomeAssistant, entry: ConfigEntry, api: transmission_rpc.Client
) -> None:
"""Initialize the Transmission RPC API."""
self.config_entry = entry
self.api = api
self.host = entry.data[CONF_HOST]
self._session: transmission_rpc.Session | None = None
self._all_torrents: list[transmission_rpc.Torrent] = []
self._completed_torrents: list[transmission_rpc.Torrent] = []
self._started_torrents: list[transmission_rpc.Torrent] = []
self.torrents: list[transmission_rpc.Torrent] = []
super().__init__(
hass,
name=f"{DOMAIN} - {self.host}",
logger=_LOGGER,
update_interval=timedelta(seconds=self.scan_interval),
)
@property
def scan_interval(self) -> float:
"""Return scan interval."""
return self.config_entry.options.get(CONF_SCAN_INTERVAL, DEFAULT_SCAN_INTERVAL)
@property
def limit(self) -> int:
"""Return limit."""
return self.config_entry.data.get(CONF_LIMIT, DEFAULT_LIMIT)
@property
def order(self) -> str:
"""Return order."""
return self.config_entry.data.get(CONF_ORDER, DEFAULT_ORDER)
async def _async_update_data(self) -> SessionStats:
"""Update transmission data."""
return await self.hass.async_add_executor_job(self.update)
def update(self) -> SessionStats:
"""Get the latest data from Transmission instance."""
try:
data = self.api.session_stats()
self.torrents = self.api.get_torrents()
self._session = self.api.get_session()
self.check_completed_torrent()
self.check_started_torrent()
self.check_removed_torrent()
except transmission_rpc.TransmissionError as err:
raise UpdateFailed("Unable to connect to Transmission client") from err
return data
def init_torrent_list(self) -> None:
"""Initialize torrent lists."""
self.torrents = self.api.get_torrents()
self._completed_torrents = [
torrent for torrent in self.torrents if torrent.status == "seeding"
]
self._started_torrents = [
torrent for torrent in self.torrents if torrent.status == "downloading"
]
def check_completed_torrent(self) -> None:
"""Get completed torrent functionality."""
old_completed_torrent_names = {
torrent.name for torrent in self._completed_torrents
}
current_completed_torrents = [
torrent for torrent in self.torrents if torrent.status == "seeding"
]
for torrent in current_completed_torrents:
if torrent.name not in old_completed_torrent_names:
self.hass.bus.fire(
EVENT_DOWNLOADED_TORRENT, {"name": torrent.name, "id": torrent.id}
)
self._completed_torrents = current_completed_torrents
def check_started_torrent(self) -> None:
"""Get started torrent functionality."""
old_started_torrent_names = {torrent.name for torrent in self._started_torrents}
current_started_torrents = [
torrent for torrent in self.torrents if torrent.status == "downloading"
]
for torrent in current_started_torrents:
if torrent.name not in old_started_torrent_names:
self.hass.bus.fire(
EVENT_STARTED_TORRENT, {"name": torrent.name, "id": torrent.id}
)
self._started_torrents = current_started_torrents
def check_removed_torrent(self) -> None:
"""Get removed torrent functionality."""
current_torrent_names = {torrent.name for torrent in self.torrents}
for torrent in self._all_torrents:
if torrent.name not in current_torrent_names:
self.hass.bus.fire(
EVENT_REMOVED_TORRENT, {"name": torrent.name, "id": torrent.id}
)
self._all_torrents = self.torrents.copy()
def start_torrents(self) -> None:
"""Start all torrents."""
if not self.torrents:
return
self.api.start_all()
def stop_torrents(self) -> None:
"""Stop all active torrents."""
if not self.torrents:
return
torrent_ids = [torrent.id for torrent in self.torrents]
self.api.stop_torrent(torrent_ids)
def set_alt_speed_enabled(self, is_enabled: bool) -> None:
"""Set the alternative speed flag."""
self.api.set_session(alt_speed_enabled=is_enabled)
def get_alt_speed_enabled(self) -> bool | None:
"""Get the alternative speed flag."""
if self._session is None:
return None
return self._session.alt_speed_enabled

View File

@ -4,21 +4,18 @@ from __future__ import annotations
from contextlib import suppress from contextlib import suppress
from typing import Any from typing import Any
from transmission_rpc.session import SessionStats
from transmission_rpc.torrent import Torrent from transmission_rpc.torrent import Torrent
from homeassistant.components.sensor import SensorDeviceClass, SensorEntity from homeassistant.components.sensor import SensorDeviceClass, SensorEntity
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_NAME, STATE_IDLE, UnitOfDataRate from homeassistant.const import CONF_NAME, STATE_IDLE, UnitOfDataRate
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant
from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.typing import StateType from homeassistant.helpers.update_coordinator import CoordinatorEntity
from . import TransmissionClient
from .const import ( from .const import (
CONF_LIMIT,
CONF_ORDER,
DOMAIN, DOMAIN,
STATE_ATTR_TORRENT_INFO, STATE_ATTR_TORRENT_INFO,
STATE_DOWNLOADING, STATE_DOWNLOADING,
@ -26,6 +23,7 @@ from .const import (
STATE_UP_DOWN, STATE_UP_DOWN,
SUPPORTED_ORDER_MODES, SUPPORTED_ORDER_MODES,
) )
from .coordinator import TransmissionDataUpdateCoordinator
async def async_setup_entry( async def async_setup_entry(
@ -35,54 +33,56 @@ async def async_setup_entry(
) -> None: ) -> None:
"""Set up the Transmission sensors.""" """Set up the Transmission sensors."""
tm_client: TransmissionClient = hass.data[DOMAIN][config_entry.entry_id] coordinator: TransmissionDataUpdateCoordinator = hass.data[DOMAIN][
config_entry.entry_id
]
name: str = config_entry.data[CONF_NAME] name: str = config_entry.data[CONF_NAME]
dev = [ dev = [
TransmissionSpeedSensor( TransmissionSpeedSensor(
tm_client, coordinator,
name, name,
"download_speed", "download_speed",
"download", "download",
), ),
TransmissionSpeedSensor( TransmissionSpeedSensor(
tm_client, coordinator,
name, name,
"upload_speed", "upload_speed",
"upload", "upload",
), ),
TransmissionStatusSensor( TransmissionStatusSensor(
tm_client, coordinator,
name, name,
"transmission_status", "transmission_status",
"status", "status",
), ),
TransmissionTorrentsSensor( TransmissionTorrentsSensor(
tm_client, coordinator,
name, name,
"active_torrents", "active_torrents",
"active_torrents", "active_torrents",
), ),
TransmissionTorrentsSensor( TransmissionTorrentsSensor(
tm_client, coordinator,
name, name,
"paused_torrents", "paused_torrents",
"paused_torrents", "paused_torrents",
), ),
TransmissionTorrentsSensor( TransmissionTorrentsSensor(
tm_client, coordinator,
name, name,
"total_torrents", "total_torrents",
"total_torrents", "total_torrents",
), ),
TransmissionTorrentsSensor( TransmissionTorrentsSensor(
tm_client, coordinator,
name, name,
"completed_torrents", "completed_torrents",
"completed_torrents", "completed_torrents",
), ),
TransmissionTorrentsSensor( TransmissionTorrentsSensor(
tm_client, coordinator,
name, name,
"started_torrents", "started_torrents",
"started_torrents", "started_torrents",
@ -92,7 +92,7 @@ async def async_setup_entry(
async_add_entities(dev, True) async_add_entities(dev, True)
class TransmissionSensor(SensorEntity): class TransmissionSensor(CoordinatorEntity[SessionStats], SensorEntity):
"""A base class for all Transmission sensors.""" """A base class for all Transmission sensors."""
_attr_has_entity_name = True _attr_has_entity_name = True
@ -100,48 +100,23 @@ class TransmissionSensor(SensorEntity):
def __init__( def __init__(
self, self,
tm_client: TransmissionClient, coordinator: TransmissionDataUpdateCoordinator,
client_name: str, client_name: str,
sensor_translation_key: str, sensor_translation_key: str,
key: str, key: str,
) -> None: ) -> None:
"""Initialize the sensor.""" """Initialize the sensor."""
self._tm_client = tm_client super().__init__(coordinator)
self._attr_translation_key = sensor_translation_key self._attr_translation_key = sensor_translation_key
self._key = key self._key = key
self._state: StateType = None self._attr_unique_id = f"{coordinator.config_entry.entry_id}-{key}"
self._attr_unique_id = f"{tm_client.config_entry.entry_id}-{key}"
self._attr_device_info = DeviceInfo( self._attr_device_info = DeviceInfo(
entry_type=DeviceEntryType.SERVICE, entry_type=DeviceEntryType.SERVICE,
identifiers={(DOMAIN, tm_client.config_entry.entry_id)}, identifiers={(DOMAIN, coordinator.config_entry.entry_id)},
manufacturer="Transmission", manufacturer="Transmission",
name=client_name, name=client_name,
) )
@property
def native_value(self) -> StateType:
"""Return the state of the sensor."""
return self._state
@property
def available(self) -> bool:
"""Could the device be accessed during the last update call."""
return self._tm_client.api.available
async def async_added_to_hass(self) -> None:
"""Handle entity which will be added."""
@callback
def update():
"""Update the state."""
self.async_schedule_update_ha_state(True)
self.async_on_remove(
async_dispatcher_connect(
self.hass, self._tm_client.api.signal_update, update
)
)
class TransmissionSpeedSensor(TransmissionSensor): class TransmissionSpeedSensor(TransmissionSensor):
"""Representation of a Transmission speed sensor.""" """Representation of a Transmission speed sensor."""
@ -151,15 +126,15 @@ class TransmissionSpeedSensor(TransmissionSensor):
_attr_suggested_display_precision = 2 _attr_suggested_display_precision = 2
_attr_suggested_unit_of_measurement = UnitOfDataRate.MEGABYTES_PER_SECOND _attr_suggested_unit_of_measurement = UnitOfDataRate.MEGABYTES_PER_SECOND
def update(self) -> None: @property
"""Get the latest data from Transmission and updates the state.""" def native_value(self) -> float:
if data := self._tm_client.api.data: """Return the speed of the sensor."""
b_spd = ( data = self.coordinator.data
return (
float(data.download_speed) float(data.download_speed)
if self._key == "download" if self._key == "download"
else float(data.upload_speed) else float(data.upload_speed)
) )
self._state = b_spd
class TransmissionStatusSensor(TransmissionSensor): class TransmissionStatusSensor(TransmissionSensor):
@ -168,21 +143,18 @@ class TransmissionStatusSensor(TransmissionSensor):
_attr_device_class = SensorDeviceClass.ENUM _attr_device_class = SensorDeviceClass.ENUM
_attr_options = [STATE_IDLE, STATE_UP_DOWN, STATE_SEEDING, STATE_DOWNLOADING] _attr_options = [STATE_IDLE, STATE_UP_DOWN, STATE_SEEDING, STATE_DOWNLOADING]
def update(self) -> None: @property
"""Get the latest data from Transmission and updates the state.""" def native_value(self) -> str:
if data := self._tm_client.api.data: """Return the value of the status sensor."""
upload = data.upload_speed upload = self.coordinator.data.upload_speed
download = data.download_speed download = self.coordinator.data.download_speed
if upload > 0 and download > 0: if upload > 0 and download > 0:
self._state = STATE_UP_DOWN return STATE_UP_DOWN
elif upload > 0 and download == 0: if upload > 0 and download == 0:
self._state = STATE_SEEDING return STATE_SEEDING
elif upload == 0 and download > 0: if upload == 0 and download > 0:
self._state = STATE_DOWNLOADING return STATE_DOWNLOADING
else: return STATE_IDLE
self._state = STATE_IDLE
else:
self._state = None
class TransmissionTorrentsSensor(TransmissionSensor): class TransmissionTorrentsSensor(TransmissionSensor):
@ -208,21 +180,22 @@ class TransmissionTorrentsSensor(TransmissionSensor):
def extra_state_attributes(self) -> dict[str, Any]: def extra_state_attributes(self) -> dict[str, Any]:
"""Return the state attributes, if any.""" """Return the state attributes, if any."""
info = _torrents_info( info = _torrents_info(
torrents=self._tm_client.api.torrents, torrents=self.coordinator.torrents,
order=self._tm_client.config_entry.options[CONF_ORDER], order=self.coordinator.order,
limit=self._tm_client.config_entry.options[CONF_LIMIT], limit=self.coordinator.limit,
statuses=self.MODES[self._key], statuses=self.MODES[self._key],
) )
return { return {
STATE_ATTR_TORRENT_INFO: info, STATE_ATTR_TORRENT_INFO: info,
} }
def update(self) -> None: @property
"""Get the latest data from Transmission and updates the state.""" def native_value(self) -> int:
"""Return the count of the sensor."""
torrents = _filter_torrents( torrents = _filter_torrents(
self._tm_client.api.torrents, statuses=self.MODES[self._key] self.coordinator.torrents, statuses=self.MODES[self._key]
) )
self._state = len(torrents) return len(torrents)
def _filter_torrents( def _filter_torrents(

View File

@ -3,16 +3,18 @@ from collections.abc import Callable
import logging import logging
from typing import Any from typing import Any
from transmission_rpc.session import SessionStats
from homeassistant.components.switch import SwitchEntity from homeassistant.components.switch import SwitchEntity
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_NAME, STATE_OFF, STATE_ON from homeassistant.const import CONF_NAME
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant
from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.update_coordinator import CoordinatorEntity
from . import TransmissionClient
from .const import DOMAIN, SWITCH_TYPES from .const import DOMAIN, SWITCH_TYPES
from .coordinator import TransmissionDataUpdateCoordinator
_LOGGING = logging.getLogger(__name__) _LOGGING = logging.getLogger(__name__)
@ -24,17 +26,19 @@ async def async_setup_entry(
) -> None: ) -> None:
"""Set up the Transmission switch.""" """Set up the Transmission switch."""
tm_client: TransmissionClient = hass.data[DOMAIN][config_entry.entry_id] coordinator: TransmissionDataUpdateCoordinator = hass.data[DOMAIN][
config_entry.entry_id
]
name: str = config_entry.data[CONF_NAME] name: str = config_entry.data[CONF_NAME]
dev = [] dev = []
for switch_type, switch_name in SWITCH_TYPES.items(): for switch_type, switch_name in SWITCH_TYPES.items():
dev.append(TransmissionSwitch(switch_type, switch_name, tm_client, name)) dev.append(TransmissionSwitch(switch_type, switch_name, coordinator, name))
async_add_entities(dev, True) async_add_entities(dev, True)
class TransmissionSwitch(SwitchEntity): class TransmissionSwitch(CoordinatorEntity[SessionStats], SwitchEntity):
"""Representation of a Transmission switch.""" """Representation of a Transmission switch."""
_attr_has_entity_name = True _attr_has_entity_name = True
@ -44,20 +48,18 @@ class TransmissionSwitch(SwitchEntity):
self, self,
switch_type: str, switch_type: str,
switch_name: str, switch_name: str,
tm_client: TransmissionClient, coordinator: TransmissionDataUpdateCoordinator,
client_name: str, client_name: str,
) -> None: ) -> None:
"""Initialize the Transmission switch.""" """Initialize the Transmission switch."""
super().__init__(coordinator)
self._attr_name = switch_name self._attr_name = switch_name
self.type = switch_type self.type = switch_type
self._tm_client = tm_client
self._state = STATE_OFF
self._data = None
self.unsub_update: Callable[[], None] | None = None self.unsub_update: Callable[[], None] | None = None
self._attr_unique_id = f"{tm_client.config_entry.entry_id}-{switch_type}" self._attr_unique_id = f"{coordinator.config_entry.entry_id}-{switch_type}"
self._attr_device_info = DeviceInfo( self._attr_device_info = DeviceInfo(
entry_type=DeviceEntryType.SERVICE, entry_type=DeviceEntryType.SERVICE,
identifiers={(DOMAIN, tm_client.config_entry.entry_id)}, identifiers={(DOMAIN, coordinator.config_entry.entry_id)},
manufacturer="Transmission", manufacturer="Transmission",
name=client_name, name=client_name,
) )
@ -65,63 +67,34 @@ class TransmissionSwitch(SwitchEntity):
@property @property
def is_on(self) -> bool: def is_on(self) -> bool:
"""Return true if device is on.""" """Return true if device is on."""
return self._state == STATE_ON active = None
if self.type == "on_off":
active = self.coordinator.data.active_torrent_count > 0
elif self.type == "turtle_mode":
active = self.coordinator.get_alt_speed_enabled()
@property return bool(active)
def available(self) -> bool:
"""Could the device be accessed during the last update call."""
return self._tm_client.api.available
def turn_on(self, **kwargs: Any) -> None: async def async_turn_on(self, **kwargs: Any) -> None:
"""Turn the device on.""" """Turn the device on."""
if self.type == "on_off": if self.type == "on_off":
_LOGGING.debug("Starting all torrents") _LOGGING.debug("Starting all torrents")
self._tm_client.api.start_torrents() await self.hass.async_add_executor_job(self.coordinator.start_torrents)
elif self.type == "turtle_mode": elif self.type == "turtle_mode":
_LOGGING.debug("Turning Turtle Mode of Transmission on") _LOGGING.debug("Turning Turtle Mode of Transmission on")
self._tm_client.api.set_alt_speed_enabled(True) await self.hass.async_add_executor_job(
self._tm_client.api.update() self.coordinator.set_alt_speed_enabled, True
)
await self.coordinator.async_request_refresh()
def turn_off(self, **kwargs: Any) -> None: async def async_turn_off(self, **kwargs: Any) -> None:
"""Turn the device off.""" """Turn the device off."""
if self.type == "on_off": if self.type == "on_off":
_LOGGING.debug("Stopping all torrents") _LOGGING.debug("Stopping all torrents")
self._tm_client.api.stop_torrents() await self.hass.async_add_executor_job(self.coordinator.stop_torrents)
if self.type == "turtle_mode": if self.type == "turtle_mode":
_LOGGING.debug("Turning Turtle Mode of Transmission off") _LOGGING.debug("Turning Turtle Mode of Transmission off")
self._tm_client.api.set_alt_speed_enabled(False) await self.hass.async_add_executor_job(
self._tm_client.api.update() self.coordinator.set_alt_speed_enabled, False
async def async_added_to_hass(self) -> None:
"""Handle entity which will be added."""
self.unsub_update = async_dispatcher_connect(
self.hass,
self._tm_client.api.signal_update,
self._schedule_immediate_update,
) )
await self.coordinator.async_request_refresh()
@callback
def _schedule_immediate_update(self) -> None:
self.async_schedule_update_ha_state(True)
async def will_remove_from_hass(self) -> None:
"""Unsubscribe from update dispatcher."""
if self.unsub_update:
self.unsub_update()
self.unsub_update = None
def update(self) -> None:
"""Get the latest data from Transmission and updates the state."""
active = None
if self.type == "on_off":
self._data = self._tm_client.api.data
if self._data:
active = self._data.active_torrent_count > 0
elif self.type == "turtle_mode":
active = self._tm_client.api.get_alt_speed_enabled()
if active is None:
return
self._state = STATE_ON if active else STATE_OFF