mirror of
https://github.com/home-assistant/core.git
synced 2025-07-25 14:17:45 +00:00
Allow multiple Transmission clients and add unique_id to entities (#28136)
* Allow multiple clients + improvements * remove commented code * fixed test_init.py
This commit is contained in:
parent
062ec8a7c2
commit
7cb6607b1f
@ -703,7 +703,6 @@ omit =
|
||||
homeassistant/components/tradfri/base_class.py
|
||||
homeassistant/components/trafikverket_train/sensor.py
|
||||
homeassistant/components/trafikverket_weatherstation/sensor.py
|
||||
homeassistant/components/transmission/__init__.py
|
||||
homeassistant/components/transmission/sensor.py
|
||||
homeassistant/components/transmission/switch.py
|
||||
homeassistant/components/transmission/const.py
|
||||
|
@ -1,39 +1,34 @@
|
||||
{
|
||||
"config": {
|
||||
"abort": {
|
||||
"one_instance_allowed": "Only a single instance is necessary."
|
||||
},
|
||||
"error": {
|
||||
"cannot_connect": "Unable to Connect to host",
|
||||
"wrong_credentials": "Wrong username or password"
|
||||
},
|
||||
"title": "Transmission",
|
||||
"step": {
|
||||
"options": {
|
||||
"data": {
|
||||
"scan_interval": "Update frequency"
|
||||
},
|
||||
"title": "Configure Options"
|
||||
},
|
||||
"user": {
|
||||
"title": "Setup Transmission Client",
|
||||
"data": {
|
||||
"host": "Host",
|
||||
"name": "Name",
|
||||
"host": "Host",
|
||||
"username": "Username",
|
||||
"password": "Password",
|
||||
"port": "Port",
|
||||
"username": "Username"
|
||||
},
|
||||
"title": "Setup Transmission Client"
|
||||
"port": "Port"
|
||||
}
|
||||
}
|
||||
},
|
||||
"title": "Transmission"
|
||||
"error": {
|
||||
"name_exists": "Name already exists",
|
||||
"wrong_credentials": "Wrong username or password",
|
||||
"cannot_connect": "Unable to Connect to host"
|
||||
},
|
||||
"abort": {
|
||||
"already_configured": "Host is already configured."
|
||||
}
|
||||
},
|
||||
"options": {
|
||||
"step": {
|
||||
"init": {
|
||||
"title": "Configure options for Transmission",
|
||||
"data": {
|
||||
"scan_interval": "Update frequency"
|
||||
},
|
||||
"description": "Configure options for Transmission"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -19,11 +19,10 @@ from homeassistant.exceptions import ConfigEntryNotReady
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers.dispatcher import dispatcher_send
|
||||
from homeassistant.helpers.event import async_track_time_interval
|
||||
from homeassistant.util import slugify
|
||||
|
||||
from .const import (
|
||||
ATTR_TORRENT,
|
||||
DATA_TRANSMISSION,
|
||||
DATA_UPDATED,
|
||||
DEFAULT_NAME,
|
||||
DEFAULT_PORT,
|
||||
DEFAULT_SCAN_INTERVAL,
|
||||
@ -37,74 +36,77 @@ _LOGGER = logging.getLogger(__name__)
|
||||
|
||||
SERVICE_ADD_TORRENT_SCHEMA = vol.Schema({vol.Required(ATTR_TORRENT): cv.string})
|
||||
|
||||
TRANS_SCHEMA = vol.All(
|
||||
vol.Schema(
|
||||
{
|
||||
vol.Required(CONF_HOST): cv.string,
|
||||
vol.Optional(CONF_PASSWORD): cv.string,
|
||||
vol.Optional(CONF_USERNAME): cv.string,
|
||||
vol.Optional(CONF_PORT, default=DEFAULT_PORT): cv.port,
|
||||
vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string,
|
||||
vol.Optional(
|
||||
CONF_SCAN_INTERVAL, default=DEFAULT_SCAN_INTERVAL
|
||||
): cv.time_period,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
CONFIG_SCHEMA = vol.Schema(
|
||||
{
|
||||
DOMAIN: vol.Schema(
|
||||
{
|
||||
vol.Required(CONF_HOST): cv.string,
|
||||
vol.Optional(CONF_PASSWORD): cv.string,
|
||||
vol.Optional(CONF_USERNAME): cv.string,
|
||||
vol.Optional(CONF_PORT, default=DEFAULT_PORT): cv.port,
|
||||
vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string,
|
||||
vol.Optional(
|
||||
CONF_SCAN_INTERVAL, default=DEFAULT_SCAN_INTERVAL
|
||||
): cv.time_period,
|
||||
}
|
||||
)
|
||||
},
|
||||
extra=vol.ALLOW_EXTRA,
|
||||
{DOMAIN: vol.All(cv.ensure_list, [TRANS_SCHEMA])}, extra=vol.ALLOW_EXTRA
|
||||
)
|
||||
|
||||
|
||||
async def async_setup(hass, config):
|
||||
"""Import the Transmission Component from config."""
|
||||
if not hass.config_entries.async_entries(DOMAIN) and DOMAIN in config:
|
||||
hass.async_create_task(
|
||||
hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": SOURCE_IMPORT}, data=config[DOMAIN]
|
||||
if DOMAIN in config:
|
||||
for entry in config[DOMAIN]:
|
||||
hass.async_create_task(
|
||||
hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": SOURCE_IMPORT}, data=entry
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def async_setup_entry(hass, config_entry):
|
||||
"""Set up the Transmission Component."""
|
||||
if DOMAIN not in hass.data:
|
||||
hass.data[DOMAIN] = {}
|
||||
|
||||
if not config_entry.options:
|
||||
await async_populate_options(hass, config_entry)
|
||||
|
||||
client = TransmissionClient(hass, config_entry)
|
||||
client_id = config_entry.entry_id
|
||||
hass.data[DOMAIN][client_id] = client
|
||||
hass.data.setdefault(DOMAIN, {})[config_entry.entry_id] = client
|
||||
|
||||
if not await client.async_setup():
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def async_unload_entry(hass, entry):
|
||||
async def async_unload_entry(hass, config_entry):
|
||||
"""Unload Transmission Entry from config_entry."""
|
||||
hass.services.async_remove(DOMAIN, SERVICE_ADD_TORRENT)
|
||||
if hass.data[DOMAIN][entry.entry_id].unsub_timer:
|
||||
hass.data[DOMAIN][entry.entry_id].unsub_timer()
|
||||
client = hass.data[DOMAIN][config_entry.entry_id]
|
||||
hass.services.async_remove(DOMAIN, client.service_name)
|
||||
if client.unsub_timer:
|
||||
client.unsub_timer()
|
||||
|
||||
for component in "sensor", "switch":
|
||||
await hass.config_entries.async_forward_entry_unload(entry, component)
|
||||
await hass.config_entries.async_forward_entry_unload(config_entry, component)
|
||||
|
||||
del hass.data[DOMAIN]
|
||||
hass.data[DOMAIN].pop(config_entry.entry_id)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def get_api(hass, host, port, username=None, password=None):
|
||||
async def get_api(hass, entry):
|
||||
"""Get Transmission client."""
|
||||
host = entry[CONF_HOST]
|
||||
port = entry[CONF_PORT]
|
||||
username = entry.get(CONF_USERNAME)
|
||||
password = entry.get(CONF_PASSWORD)
|
||||
|
||||
try:
|
||||
api = await hass.async_add_executor_job(
|
||||
transmissionrpc.Client, host, port, username, password
|
||||
)
|
||||
_LOGGER.debug("Successfully connected to %s", host)
|
||||
return api
|
||||
|
||||
except TransmissionError as error:
|
||||
@ -112,20 +114,13 @@ async def get_api(hass, host, port, username=None, password=None):
|
||||
_LOGGER.error("Credentials for Transmission client are not valid")
|
||||
raise AuthenticationError
|
||||
if "111: Connection refused" in str(error):
|
||||
_LOGGER.error("Connecting to the Transmission client failed")
|
||||
_LOGGER.error("Connecting to the Transmission client %s failed", host)
|
||||
raise CannotConnect
|
||||
|
||||
_LOGGER.error(error)
|
||||
raise UnknownError
|
||||
|
||||
|
||||
async def async_populate_options(hass, config_entry):
|
||||
"""Populate default options for Transmission Client."""
|
||||
options = {CONF_SCAN_INTERVAL: config_entry.data["options"][CONF_SCAN_INTERVAL]}
|
||||
|
||||
hass.config_entries.async_update_entry(config_entry, options=options)
|
||||
|
||||
|
||||
class TransmissionClient:
|
||||
"""Transmission Client Object."""
|
||||
|
||||
@ -133,33 +128,35 @@ class TransmissionClient:
|
||||
"""Initialize the Transmission RPC API."""
|
||||
self.hass = hass
|
||||
self.config_entry = config_entry
|
||||
self.scan_interval = self.config_entry.options[CONF_SCAN_INTERVAL]
|
||||
self.tm_data = None
|
||||
self._tm_data = None
|
||||
self.unsub_timer = None
|
||||
|
||||
@property
|
||||
def service_name(self):
|
||||
"""Return the service name."""
|
||||
return slugify(f"{SERVICE_ADD_TORRENT}_{self.config_entry.data[CONF_NAME]}")
|
||||
|
||||
@property
|
||||
def api(self):
|
||||
"""Return the tm_data object."""
|
||||
return self._tm_data
|
||||
|
||||
async def async_setup(self):
|
||||
"""Set up the Transmission client."""
|
||||
|
||||
config = {
|
||||
CONF_HOST: self.config_entry.data[CONF_HOST],
|
||||
CONF_PORT: self.config_entry.data[CONF_PORT],
|
||||
CONF_USERNAME: self.config_entry.data.get(CONF_USERNAME),
|
||||
CONF_PASSWORD: self.config_entry.data.get(CONF_PASSWORD),
|
||||
}
|
||||
try:
|
||||
api = await get_api(self.hass, **config)
|
||||
api = await get_api(self.hass, self.config_entry.data)
|
||||
except CannotConnect:
|
||||
raise ConfigEntryNotReady
|
||||
except (AuthenticationError, UnknownError):
|
||||
return False
|
||||
|
||||
self.tm_data = self.hass.data[DOMAIN][DATA_TRANSMISSION] = TransmissionData(
|
||||
self.hass, self.config_entry, api
|
||||
)
|
||||
self._tm_data = TransmissionData(self.hass, self.config_entry, api)
|
||||
|
||||
await self.hass.async_add_executor_job(self.tm_data.init_torrent_list)
|
||||
await self.hass.async_add_executor_job(self.tm_data.update)
|
||||
self.set_scan_interval(self.scan_interval)
|
||||
await self.hass.async_add_executor_job(self._tm_data.init_torrent_list)
|
||||
await self.hass.async_add_executor_job(self._tm_data.update)
|
||||
self.add_options()
|
||||
self.set_scan_interval(self.config_entry.options[CONF_SCAN_INTERVAL])
|
||||
|
||||
for platform in ["sensor", "switch"]:
|
||||
self.hass.async_create_task(
|
||||
@ -181,19 +178,31 @@ class TransmissionClient:
|
||||
)
|
||||
|
||||
self.hass.services.async_register(
|
||||
DOMAIN, SERVICE_ADD_TORRENT, add_torrent, schema=SERVICE_ADD_TORRENT_SCHEMA
|
||||
DOMAIN, self.service_name, add_torrent, schema=SERVICE_ADD_TORRENT_SCHEMA
|
||||
)
|
||||
|
||||
self.config_entry.add_update_listener(self.async_options_updated)
|
||||
|
||||
return True
|
||||
|
||||
def add_options(self):
|
||||
"""Add options for entry."""
|
||||
if not self.config_entry.options:
|
||||
scan_interval = self.config_entry.data.pop(
|
||||
CONF_SCAN_INTERVAL, DEFAULT_SCAN_INTERVAL
|
||||
)
|
||||
options = {CONF_SCAN_INTERVAL: scan_interval}
|
||||
|
||||
self.hass.config_entries.async_update_entry(
|
||||
self.config_entry, options=options
|
||||
)
|
||||
|
||||
def set_scan_interval(self, scan_interval):
|
||||
"""Update scan interval."""
|
||||
|
||||
def refresh(event_time):
|
||||
async def refresh(event_time):
|
||||
"""Get the latest data from Transmission."""
|
||||
self.tm_data.update()
|
||||
self._tm_data.update()
|
||||
|
||||
if self.unsub_timer is not None:
|
||||
self.unsub_timer()
|
||||
@ -215,6 +224,7 @@ class TransmissionData:
|
||||
def __init__(self, hass, config, api):
|
||||
"""Initialize the Transmission RPC API."""
|
||||
self.hass = hass
|
||||
self.config = config
|
||||
self.data = None
|
||||
self.torrents = None
|
||||
self.session = None
|
||||
@ -223,6 +233,16 @@ class TransmissionData:
|
||||
self.completed_torrents = []
|
||||
self.started_torrents = []
|
||||
|
||||
@property
|
||||
def host(self):
|
||||
"""Return the host name."""
|
||||
return self.config.data[CONF_HOST]
|
||||
|
||||
@property
|
||||
def signal_options_update(self):
|
||||
"""Option update signal per transmission entry."""
|
||||
return f"tm-options-{self.host}"
|
||||
|
||||
def update(self):
|
||||
"""Get the latest data from Transmission instance."""
|
||||
try:
|
||||
@ -232,14 +252,13 @@ class TransmissionData:
|
||||
|
||||
self.check_completed_torrent()
|
||||
self.check_started_torrent()
|
||||
_LOGGER.debug("Torrent Data Updated")
|
||||
_LOGGER.debug("Torrent Data for %s Updated", self.host)
|
||||
|
||||
self.available = True
|
||||
except TransmissionError:
|
||||
self.available = False
|
||||
_LOGGER.error("Unable to connect to Transmission client")
|
||||
|
||||
dispatcher_send(self.hass, DATA_UPDATED)
|
||||
_LOGGER.error("Unable to connect to Transmission client %s", self.host)
|
||||
dispatcher_send(self.hass, self.signal_options_update)
|
||||
|
||||
def init_torrent_list(self):
|
||||
"""Initialize torrent lists."""
|
||||
|
@ -29,32 +29,32 @@ class TransmissionFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
"""Get the options flow for this handler."""
|
||||
return TransmissionOptionsFlowHandler(config_entry)
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the Transmission flow."""
|
||||
self.config = {}
|
||||
self.errors = {}
|
||||
|
||||
async def async_step_user(self, user_input=None):
|
||||
"""Handle a flow initialized by the user."""
|
||||
if self.hass.config_entries.async_entries(DOMAIN):
|
||||
return self.async_abort(reason="one_instance_allowed")
|
||||
errors = {}
|
||||
|
||||
if user_input is not None:
|
||||
|
||||
self.config[CONF_NAME] = user_input.pop(CONF_NAME)
|
||||
for entry in self.hass.config_entries.async_entries(DOMAIN):
|
||||
if entry.data[CONF_HOST] == user_input[CONF_HOST]:
|
||||
return self.async_abort(reason="already_configured")
|
||||
if entry.data[CONF_NAME] == user_input[CONF_NAME]:
|
||||
errors[CONF_NAME] = "name_exists"
|
||||
break
|
||||
|
||||
try:
|
||||
await get_api(self.hass, **user_input)
|
||||
self.config.update(user_input)
|
||||
if "options" not in self.config:
|
||||
self.config["options"] = {CONF_SCAN_INTERVAL: DEFAULT_SCAN_INTERVAL}
|
||||
return self.async_create_entry(
|
||||
title=self.config[CONF_NAME], data=self.config
|
||||
)
|
||||
await get_api(self.hass, user_input)
|
||||
|
||||
except AuthenticationError:
|
||||
self.errors[CONF_USERNAME] = "wrong_credentials"
|
||||
self.errors[CONF_PASSWORD] = "wrong_credentials"
|
||||
errors[CONF_USERNAME] = "wrong_credentials"
|
||||
errors[CONF_PASSWORD] = "wrong_credentials"
|
||||
except (CannotConnect, UnknownError):
|
||||
self.errors["base"] = "cannot_connect"
|
||||
errors["base"] = "cannot_connect"
|
||||
|
||||
if not errors:
|
||||
return self.async_create_entry(
|
||||
title=user_input[CONF_NAME], data=user_input
|
||||
)
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="user",
|
||||
@ -67,15 +67,12 @@ class TransmissionFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
vol.Required(CONF_PORT, default=DEFAULT_PORT): int,
|
||||
}
|
||||
),
|
||||
errors=self.errors,
|
||||
errors=errors,
|
||||
)
|
||||
|
||||
async def async_step_import(self, import_config):
|
||||
"""Import from Transmission client config."""
|
||||
self.config["options"] = {
|
||||
CONF_SCAN_INTERVAL: import_config.pop(CONF_SCAN_INTERVAL).seconds
|
||||
}
|
||||
|
||||
import_config[CONF_SCAN_INTERVAL] = import_config[CONF_SCAN_INTERVAL].seconds
|
||||
return await self.async_step_user(user_input=import_config)
|
||||
|
||||
|
||||
@ -95,8 +92,7 @@ class TransmissionOptionsFlowHandler(config_entries.OptionsFlow):
|
||||
vol.Optional(
|
||||
CONF_SCAN_INTERVAL,
|
||||
default=self.config_entry.options.get(
|
||||
CONF_SCAN_INTERVAL,
|
||||
self.config_entry.data["options"][CONF_SCAN_INTERVAL],
|
||||
CONF_SCAN_INTERVAL, DEFAULT_SCAN_INTERVAL
|
||||
),
|
||||
): int
|
||||
}
|
||||
|
@ -21,4 +21,3 @@ ATTR_TORRENT = "torrent"
|
||||
SERVICE_ADD_TORRENT = "add_torrent"
|
||||
|
||||
DATA_UPDATED = "transmission_data_updated"
|
||||
DATA_TRANSMISSION = "data_transmission"
|
||||
|
@ -6,7 +6,7 @@ from homeassistant.core import callback
|
||||
from homeassistant.helpers.dispatcher import async_dispatcher_connect
|
||||
from homeassistant.helpers.entity import Entity
|
||||
|
||||
from .const import DATA_TRANSMISSION, DATA_UPDATED, DOMAIN, SENSOR_TYPES
|
||||
from .const import DOMAIN, SENSOR_TYPES
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@ -19,7 +19,7 @@ async def async_setup_platform(hass, config, async_add_entities, discovery_info=
|
||||
async def async_setup_entry(hass, config_entry, async_add_entities):
|
||||
"""Set up the Transmission sensors."""
|
||||
|
||||
transmission_api = hass.data[DOMAIN][DATA_TRANSMISSION]
|
||||
tm_client = hass.data[DOMAIN][config_entry.entry_id]
|
||||
name = config_entry.data[CONF_NAME]
|
||||
|
||||
dev = []
|
||||
@ -27,7 +27,7 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
|
||||
dev.append(
|
||||
TransmissionSensor(
|
||||
sensor_type,
|
||||
transmission_api,
|
||||
tm_client,
|
||||
name,
|
||||
SENSOR_TYPES[sensor_type][0],
|
||||
SENSOR_TYPES[sensor_type][1],
|
||||
@ -41,17 +41,12 @@ class TransmissionSensor(Entity):
|
||||
"""Representation of a Transmission sensor."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sensor_type,
|
||||
transmission_api,
|
||||
client_name,
|
||||
sensor_name,
|
||||
unit_of_measurement,
|
||||
self, sensor_type, tm_client, client_name, sensor_name, unit_of_measurement
|
||||
):
|
||||
"""Initialize the sensor."""
|
||||
self._name = sensor_name
|
||||
self._state = None
|
||||
self._transmission_api = transmission_api
|
||||
self._tm_client = tm_client
|
||||
self._unit_of_measurement = unit_of_measurement
|
||||
self._data = None
|
||||
self.client_name = client_name
|
||||
@ -62,6 +57,11 @@ class TransmissionSensor(Entity):
|
||||
"""Return the name of the sensor."""
|
||||
return f"{self.client_name} {self._name}"
|
||||
|
||||
@property
|
||||
def unique_id(self):
|
||||
"""Return the unique id of the entity."""
|
||||
return f"{self._tm_client.api.host}-{self.name}"
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
"""Return the state of the sensor."""
|
||||
@ -80,12 +80,14 @@ class TransmissionSensor(Entity):
|
||||
@property
|
||||
def available(self):
|
||||
"""Could the device be accessed during the last update call."""
|
||||
return self._transmission_api.available
|
||||
return self._tm_client.api.available
|
||||
|
||||
async def async_added_to_hass(self):
|
||||
"""Handle entity which will be added."""
|
||||
async_dispatcher_connect(
|
||||
self.hass, DATA_UPDATED, self._schedule_immediate_update
|
||||
self.hass,
|
||||
self._tm_client.api.signal_options_update,
|
||||
self._schedule_immediate_update,
|
||||
)
|
||||
|
||||
@callback
|
||||
@ -94,12 +96,12 @@ class TransmissionSensor(Entity):
|
||||
|
||||
def update(self):
|
||||
"""Get the latest data from Transmission and updates the state."""
|
||||
self._data = self._transmission_api.data
|
||||
self._data = self._tm_client.api.data
|
||||
|
||||
if self.type == "completed_torrents":
|
||||
self._state = self._transmission_api.get_completed_torrent_count()
|
||||
self._state = self._tm_client.api.get_completed_torrent_count()
|
||||
elif self.type == "started_torrents":
|
||||
self._state = self._transmission_api.get_started_torrent_count()
|
||||
self._state = self._tm_client.api.get_started_torrent_count()
|
||||
|
||||
if self.type == "current_status":
|
||||
if self._data:
|
||||
|
@ -11,30 +11,25 @@
|
||||
"password": "Password",
|
||||
"port": "Port"
|
||||
}
|
||||
},
|
||||
"options": {
|
||||
"title": "Configure Options",
|
||||
"data": {
|
||||
"scan_interval": "Update frequency"
|
||||
}
|
||||
}
|
||||
},
|
||||
"error": {
|
||||
"name_exists": "Name already exists",
|
||||
"wrong_credentials": "Wrong username or password",
|
||||
"cannot_connect": "Unable to Connect to host"
|
||||
},
|
||||
"abort": {
|
||||
"one_instance_allowed": "Only a single instance is necessary."
|
||||
"already_configured": "Host is already configured."
|
||||
}
|
||||
},
|
||||
"options": {
|
||||
"step": {
|
||||
"init": {
|
||||
"description": "Configure options for Transmission",
|
||||
"title": "Configure options for Transmission",
|
||||
"data": {
|
||||
"scan_interval": "Update frequency"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -6,7 +6,7 @@ from homeassistant.core import callback
|
||||
from homeassistant.helpers.dispatcher import async_dispatcher_connect
|
||||
from homeassistant.helpers.entity import ToggleEntity
|
||||
|
||||
from .const import DATA_TRANSMISSION, DATA_UPDATED, DOMAIN, SWITCH_TYPES
|
||||
from .const import DOMAIN, SWITCH_TYPES
|
||||
|
||||
_LOGGING = logging.getLogger(__name__)
|
||||
|
||||
@ -19,12 +19,12 @@ async def async_setup_platform(hass, config, async_add_entities, discovery_info=
|
||||
async def async_setup_entry(hass, config_entry, async_add_entities):
|
||||
"""Set up the Transmission switch."""
|
||||
|
||||
transmission_api = hass.data[DOMAIN][DATA_TRANSMISSION]
|
||||
tm_client = hass.data[DOMAIN][config_entry.entry_id]
|
||||
name = config_entry.data[CONF_NAME]
|
||||
|
||||
dev = []
|
||||
for switch_type, switch_name in SWITCH_TYPES.items():
|
||||
dev.append(TransmissionSwitch(switch_type, switch_name, transmission_api, name))
|
||||
dev.append(TransmissionSwitch(switch_type, switch_name, tm_client, name))
|
||||
|
||||
async_add_entities(dev, True)
|
||||
|
||||
@ -32,12 +32,12 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
|
||||
class TransmissionSwitch(ToggleEntity):
|
||||
"""Representation of a Transmission switch."""
|
||||
|
||||
def __init__(self, switch_type, switch_name, transmission_api, name):
|
||||
def __init__(self, switch_type, switch_name, tm_client, name):
|
||||
"""Initialize the Transmission switch."""
|
||||
self._name = switch_name
|
||||
self.client_name = name
|
||||
self.type = switch_type
|
||||
self._transmission_api = transmission_api
|
||||
self._tm_client = tm_client
|
||||
self._state = STATE_OFF
|
||||
self._data = None
|
||||
|
||||
@ -46,6 +46,11 @@ class TransmissionSwitch(ToggleEntity):
|
||||
"""Return the name of the switch."""
|
||||
return f"{self.client_name} {self._name}"
|
||||
|
||||
@property
|
||||
def unique_id(self):
|
||||
"""Return the unique id of the entity."""
|
||||
return f"{self._tm_client.api.host}-{self.name}"
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
"""Return the state of the device."""
|
||||
@ -64,32 +69,34 @@ class TransmissionSwitch(ToggleEntity):
|
||||
@property
|
||||
def available(self):
|
||||
"""Could the device be accessed during the last update call."""
|
||||
return self._transmission_api.available
|
||||
return self._tm_client.api.available
|
||||
|
||||
def turn_on(self, **kwargs):
|
||||
"""Turn the device on."""
|
||||
if self.type == "on_off":
|
||||
_LOGGING.debug("Starting all torrents")
|
||||
self._transmission_api.start_torrents()
|
||||
self._tm_client.api.start_torrents()
|
||||
elif self.type == "turtle_mode":
|
||||
_LOGGING.debug("Turning Turtle Mode of Transmission on")
|
||||
self._transmission_api.set_alt_speed_enabled(True)
|
||||
self._transmission_api.update()
|
||||
self._tm_client.api.set_alt_speed_enabled(True)
|
||||
self._tm_client.api.update()
|
||||
|
||||
def turn_off(self, **kwargs):
|
||||
"""Turn the device off."""
|
||||
if self.type == "on_off":
|
||||
_LOGGING.debug("Stoping all torrents")
|
||||
self._transmission_api.stop_torrents()
|
||||
self._tm_client.api.stop_torrents()
|
||||
if self.type == "turtle_mode":
|
||||
_LOGGING.debug("Turning Turtle Mode of Transmission off")
|
||||
self._transmission_api.set_alt_speed_enabled(False)
|
||||
self._transmission_api.update()
|
||||
self._tm_client.api.set_alt_speed_enabled(False)
|
||||
self._tm_client.api.update()
|
||||
|
||||
async def async_added_to_hass(self):
|
||||
"""Handle entity which will be added."""
|
||||
async_dispatcher_connect(
|
||||
self.hass, DATA_UPDATED, self._schedule_immediate_update
|
||||
self.hass,
|
||||
self._tm_client.api.signal_options_update,
|
||||
self._schedule_immediate_update,
|
||||
)
|
||||
|
||||
@callback
|
||||
@ -100,12 +107,12 @@ class TransmissionSwitch(ToggleEntity):
|
||||
"""Get the latest data from Transmission and updates the state."""
|
||||
active = None
|
||||
if self.type == "on_off":
|
||||
self._data = self._transmission_api.data
|
||||
self._data = self._tm_client.api.data
|
||||
if self._data:
|
||||
active = self._data.activeTorrentCount > 0
|
||||
|
||||
elif self.type == "turtle_mode":
|
||||
active = self._transmission_api.get_alt_speed_enabled()
|
||||
active = self._tm_client.api.get_alt_speed_enabled()
|
||||
|
||||
if active is None:
|
||||
return
|
||||
|
@ -1,4 +1,4 @@
|
||||
"""Tests for Met.no config flow."""
|
||||
"""Tests for Transmission config flow."""
|
||||
from datetime import timedelta
|
||||
from unittest.mock import patch
|
||||
|
||||
@ -31,6 +31,14 @@ PASSWORD = "password"
|
||||
PORT = 9091
|
||||
SCAN_INTERVAL = 10
|
||||
|
||||
MOCK_ENTRY = {
|
||||
CONF_NAME: NAME,
|
||||
CONF_HOST: HOST,
|
||||
CONF_USERNAME: USERNAME,
|
||||
CONF_PASSWORD: PASSWORD,
|
||||
CONF_PORT: PORT,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(name="api")
|
||||
def mock_transmission_api():
|
||||
@ -90,18 +98,10 @@ async def test_flow_works(hass, api):
|
||||
assert result["data"][CONF_NAME] == NAME
|
||||
assert result["data"][CONF_HOST] == HOST
|
||||
assert result["data"][CONF_PORT] == PORT
|
||||
assert result["data"]["options"][CONF_SCAN_INTERVAL] == DEFAULT_SCAN_INTERVAL
|
||||
# assert result["data"]["options"][CONF_SCAN_INTERVAL] == DEFAULT_SCAN_INTERVAL
|
||||
|
||||
# test with all provided
|
||||
result = await flow.async_step_user(
|
||||
{
|
||||
CONF_NAME: NAME,
|
||||
CONF_HOST: HOST,
|
||||
CONF_USERNAME: USERNAME,
|
||||
CONF_PASSWORD: PASSWORD,
|
||||
CONF_PORT: PORT,
|
||||
}
|
||||
)
|
||||
result = await flow.async_step_user(MOCK_ENTRY)
|
||||
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
||||
assert result["title"] == NAME
|
||||
@ -110,7 +110,7 @@ async def test_flow_works(hass, api):
|
||||
assert result["data"][CONF_USERNAME] == USERNAME
|
||||
assert result["data"][CONF_PASSWORD] == PASSWORD
|
||||
assert result["data"][CONF_PORT] == PORT
|
||||
assert result["data"]["options"][CONF_SCAN_INTERVAL] == DEFAULT_SCAN_INTERVAL
|
||||
# assert result["data"]["options"][CONF_SCAN_INTERVAL] == DEFAULT_SCAN_INTERVAL
|
||||
|
||||
|
||||
async def test_options(hass):
|
||||
@ -118,14 +118,7 @@ async def test_options(hass):
|
||||
entry = MockConfigEntry(
|
||||
domain=DOMAIN,
|
||||
title=CONF_NAME,
|
||||
data={
|
||||
"name": DEFAULT_NAME,
|
||||
"host": HOST,
|
||||
"username": USERNAME,
|
||||
"password": PASSWORD,
|
||||
"port": DEFAULT_PORT,
|
||||
"options": {CONF_SCAN_INTERVAL: DEFAULT_SCAN_INTERVAL},
|
||||
},
|
||||
data=MOCK_ENTRY,
|
||||
options={CONF_SCAN_INTERVAL: DEFAULT_SCAN_INTERVAL},
|
||||
)
|
||||
flow = init_config_flow(hass)
|
||||
@ -157,7 +150,7 @@ async def test_import(hass, api):
|
||||
assert result["data"][CONF_NAME] == DEFAULT_NAME
|
||||
assert result["data"][CONF_HOST] == HOST
|
||||
assert result["data"][CONF_PORT] == DEFAULT_PORT
|
||||
assert result["data"]["options"][CONF_SCAN_INTERVAL] == DEFAULT_SCAN_INTERVAL
|
||||
assert result["data"][CONF_SCAN_INTERVAL] == DEFAULT_SCAN_INTERVAL
|
||||
|
||||
# import with all
|
||||
result = await flow.async_step_import(
|
||||
@ -177,18 +170,40 @@ async def test_import(hass, api):
|
||||
assert result["data"][CONF_USERNAME] == USERNAME
|
||||
assert result["data"][CONF_PASSWORD] == PASSWORD
|
||||
assert result["data"][CONF_PORT] == PORT
|
||||
assert result["data"]["options"][CONF_SCAN_INTERVAL] == SCAN_INTERVAL
|
||||
assert result["data"][CONF_SCAN_INTERVAL] == SCAN_INTERVAL
|
||||
|
||||
|
||||
async def test_integration_already_exists(hass, api):
|
||||
"""Test we only allow a single config flow."""
|
||||
MockConfigEntry(domain=DOMAIN).add_to_hass(hass)
|
||||
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": "user"}
|
||||
async def test_host_already_configured(hass, api):
|
||||
"""Test host is already configured."""
|
||||
entry = MockConfigEntry(
|
||||
domain=DOMAIN,
|
||||
data=MOCK_ENTRY,
|
||||
options={CONF_SCAN_INTERVAL: DEFAULT_SCAN_INTERVAL},
|
||||
)
|
||||
entry.add_to_hass(hass)
|
||||
flow = init_config_flow(hass)
|
||||
result = await flow.async_step_user(MOCK_ENTRY)
|
||||
|
||||
assert result["type"] == "abort"
|
||||
assert result["reason"] == "one_instance_allowed"
|
||||
assert result["reason"] == "already_configured"
|
||||
|
||||
|
||||
async def test_name_already_configured(hass, api):
|
||||
"""Test name is already configured."""
|
||||
entry = MockConfigEntry(
|
||||
domain=DOMAIN,
|
||||
data=MOCK_ENTRY,
|
||||
options={CONF_SCAN_INTERVAL: DEFAULT_SCAN_INTERVAL},
|
||||
)
|
||||
entry.add_to_hass(hass)
|
||||
|
||||
mock_entry = MOCK_ENTRY.copy()
|
||||
mock_entry[CONF_HOST] = "0.0.0.0"
|
||||
flow = init_config_flow(hass)
|
||||
result = await flow.async_step_user(mock_entry)
|
||||
|
||||
assert result["type"] == "form"
|
||||
assert result["errors"] == {CONF_NAME: "name_exists"}
|
||||
|
||||
|
||||
async def test_error_on_wrong_credentials(hass, auth_error):
|
||||
|
123
tests/components/transmission/test_init.py
Normal file
123
tests/components/transmission/test_init.py
Normal file
@ -0,0 +1,123 @@
|
||||
"""Tests for Transmission init."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from transmissionrpc.error import TransmissionError
|
||||
|
||||
from homeassistant.components import transmission
|
||||
from homeassistant.exceptions import ConfigEntryNotReady
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import MockConfigEntry, mock_coro
|
||||
|
||||
MOCK_ENTRY = MockConfigEntry(
|
||||
domain=transmission.DOMAIN,
|
||||
data={
|
||||
transmission.CONF_NAME: "Transmission",
|
||||
transmission.CONF_HOST: "0.0.0.0",
|
||||
transmission.CONF_USERNAME: "user",
|
||||
transmission.CONF_PASSWORD: "pass",
|
||||
transmission.CONF_PORT: 9091,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(name="api")
|
||||
def mock_transmission_api():
|
||||
"""Mock an api."""
|
||||
with patch("transmissionrpc.Client"):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture(name="auth_error")
|
||||
def mock_api_authentication_error():
|
||||
"""Mock an api."""
|
||||
with patch(
|
||||
"transmissionrpc.Client", side_effect=TransmissionError("401: Unauthorized")
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture(name="unknown_error")
|
||||
def mock_api_unknown_error():
|
||||
"""Mock an api."""
|
||||
with patch("transmissionrpc.Client", side_effect=TransmissionError):
|
||||
yield
|
||||
|
||||
|
||||
async def test_setup_with_no_config(hass):
|
||||
"""Test that we do not discover anything or try to set up a Transmission client."""
|
||||
assert await async_setup_component(hass, transmission.DOMAIN, {}) is True
|
||||
assert transmission.DOMAIN not in hass.data
|
||||
|
||||
|
||||
async def test_setup_with_config(hass, api):
|
||||
"""Test that we import the config and setup the client."""
|
||||
config = {
|
||||
transmission.DOMAIN: {
|
||||
transmission.CONF_NAME: "Transmission",
|
||||
transmission.CONF_HOST: "0.0.0.0",
|
||||
transmission.CONF_USERNAME: "user",
|
||||
transmission.CONF_PASSWORD: "pass",
|
||||
transmission.CONF_PORT: 9091,
|
||||
},
|
||||
transmission.DOMAIN: {
|
||||
transmission.CONF_NAME: "Transmission2",
|
||||
transmission.CONF_HOST: "0.0.0.1",
|
||||
transmission.CONF_USERNAME: "user",
|
||||
transmission.CONF_PASSWORD: "pass",
|
||||
transmission.CONF_PORT: 9091,
|
||||
},
|
||||
}
|
||||
assert await async_setup_component(hass, transmission.DOMAIN, config) is True
|
||||
|
||||
|
||||
async def test_successful_config_entry(hass, api):
|
||||
"""Test that configured transmission is configured successfully."""
|
||||
|
||||
entry = MOCK_ENTRY
|
||||
entry.add_to_hass(hass)
|
||||
|
||||
assert await transmission.async_setup_entry(hass, entry) is True
|
||||
assert entry.options == {
|
||||
transmission.CONF_SCAN_INTERVAL: transmission.DEFAULT_SCAN_INTERVAL
|
||||
}
|
||||
|
||||
|
||||
async def test_setup_failed(hass):
|
||||
"""Test transmission failed due to an error."""
|
||||
|
||||
entry = MOCK_ENTRY
|
||||
entry.add_to_hass(hass)
|
||||
|
||||
# test connection error raising ConfigEntryNotReady
|
||||
with patch(
|
||||
"transmissionrpc.Client",
|
||||
side_effect=TransmissionError("111: Connection refused"),
|
||||
), pytest.raises(ConfigEntryNotReady):
|
||||
|
||||
await transmission.async_setup_entry(hass, entry)
|
||||
|
||||
# test Authentication error returning false
|
||||
|
||||
with patch(
|
||||
"transmissionrpc.Client", side_effect=TransmissionError("401: Unauthorized")
|
||||
):
|
||||
|
||||
assert await transmission.async_setup_entry(hass, entry) is False
|
||||
|
||||
|
||||
async def test_unload_entry(hass, api):
|
||||
"""Test removing transmission client."""
|
||||
entry = MOCK_ENTRY
|
||||
entry.add_to_hass(hass)
|
||||
|
||||
with patch.object(
|
||||
hass.config_entries, "async_forward_entry_unload", return_value=mock_coro(True)
|
||||
) as unload_entry:
|
||||
assert await transmission.async_setup_entry(hass, entry)
|
||||
|
||||
assert await transmission.async_unload_entry(hass, entry)
|
||||
assert unload_entry.call_count == 2
|
||||
assert entry.entry_id not in hass.data[transmission.DOMAIN]
|
Loading…
x
Reference in New Issue
Block a user