Allow multiple Transmission clients and add unique_id to entities (#28136)

* Allow multiple clients + improvements

* remove commented code

* fixed test_init.py
This commit is contained in:
Rami Mosleh 2019-10-23 23:09:11 +03:00 committed by Paulus Schoutsen
parent 062ec8a7c2
commit 7cb6607b1f
10 changed files with 333 additions and 183 deletions

View File

@ -703,7 +703,6 @@ omit =
homeassistant/components/tradfri/base_class.py
homeassistant/components/trafikverket_train/sensor.py
homeassistant/components/trafikverket_weatherstation/sensor.py
homeassistant/components/transmission/__init__.py
homeassistant/components/transmission/sensor.py
homeassistant/components/transmission/switch.py
homeassistant/components/transmission/const.py

View File

@ -1,39 +1,34 @@
{
"config": {
"abort": {
"one_instance_allowed": "Only a single instance is necessary."
},
"error": {
"cannot_connect": "Unable to Connect to host",
"wrong_credentials": "Wrong username or password"
},
"title": "Transmission",
"step": {
"options": {
"data": {
"scan_interval": "Update frequency"
},
"title": "Configure Options"
},
"user": {
"title": "Setup Transmission Client",
"data": {
"host": "Host",
"name": "Name",
"host": "Host",
"username": "Username",
"password": "Password",
"port": "Port",
"username": "Username"
},
"title": "Setup Transmission Client"
"port": "Port"
}
}
},
"title": "Transmission"
"error": {
"name_exists": "Name already exists",
"wrong_credentials": "Wrong username or password",
"cannot_connect": "Unable to Connect to host"
},
"abort": {
"already_configured": "Host is already configured."
}
},
"options": {
"step": {
"init": {
"title": "Configure options for Transmission",
"data": {
"scan_interval": "Update frequency"
},
"description": "Configure options for Transmission"
}
}
}
}

View File

@ -19,11 +19,10 @@ from homeassistant.exceptions import ConfigEntryNotReady
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.dispatcher import dispatcher_send
from homeassistant.helpers.event import async_track_time_interval
from homeassistant.util import slugify
from .const import (
ATTR_TORRENT,
DATA_TRANSMISSION,
DATA_UPDATED,
DEFAULT_NAME,
DEFAULT_PORT,
DEFAULT_SCAN_INTERVAL,
@ -37,74 +36,77 @@ _LOGGER = logging.getLogger(__name__)
SERVICE_ADD_TORRENT_SCHEMA = vol.Schema({vol.Required(ATTR_TORRENT): cv.string})
TRANS_SCHEMA = vol.All(
vol.Schema(
{
vol.Required(CONF_HOST): cv.string,
vol.Optional(CONF_PASSWORD): cv.string,
vol.Optional(CONF_USERNAME): cv.string,
vol.Optional(CONF_PORT, default=DEFAULT_PORT): cv.port,
vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string,
vol.Optional(
CONF_SCAN_INTERVAL, default=DEFAULT_SCAN_INTERVAL
): cv.time_period,
}
)
)
CONFIG_SCHEMA = vol.Schema(
{
DOMAIN: vol.Schema(
{
vol.Required(CONF_HOST): cv.string,
vol.Optional(CONF_PASSWORD): cv.string,
vol.Optional(CONF_USERNAME): cv.string,
vol.Optional(CONF_PORT, default=DEFAULT_PORT): cv.port,
vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string,
vol.Optional(
CONF_SCAN_INTERVAL, default=DEFAULT_SCAN_INTERVAL
): cv.time_period,
}
)
},
extra=vol.ALLOW_EXTRA,
{DOMAIN: vol.All(cv.ensure_list, [TRANS_SCHEMA])}, extra=vol.ALLOW_EXTRA
)
async def async_setup(hass, config):
"""Import the Transmission Component from config."""
if not hass.config_entries.async_entries(DOMAIN) and DOMAIN in config:
hass.async_create_task(
hass.config_entries.flow.async_init(
DOMAIN, context={"source": SOURCE_IMPORT}, data=config[DOMAIN]
if DOMAIN in config:
for entry in config[DOMAIN]:
hass.async_create_task(
hass.config_entries.flow.async_init(
DOMAIN, context={"source": SOURCE_IMPORT}, data=entry
)
)
)
return True
async def async_setup_entry(hass, config_entry):
"""Set up the Transmission Component."""
if DOMAIN not in hass.data:
hass.data[DOMAIN] = {}
if not config_entry.options:
await async_populate_options(hass, config_entry)
client = TransmissionClient(hass, config_entry)
client_id = config_entry.entry_id
hass.data[DOMAIN][client_id] = client
hass.data.setdefault(DOMAIN, {})[config_entry.entry_id] = client
if not await client.async_setup():
return False
return True
async def async_unload_entry(hass, entry):
async def async_unload_entry(hass, config_entry):
"""Unload Transmission Entry from config_entry."""
hass.services.async_remove(DOMAIN, SERVICE_ADD_TORRENT)
if hass.data[DOMAIN][entry.entry_id].unsub_timer:
hass.data[DOMAIN][entry.entry_id].unsub_timer()
client = hass.data[DOMAIN][config_entry.entry_id]
hass.services.async_remove(DOMAIN, client.service_name)
if client.unsub_timer:
client.unsub_timer()
for component in "sensor", "switch":
await hass.config_entries.async_forward_entry_unload(entry, component)
await hass.config_entries.async_forward_entry_unload(config_entry, component)
del hass.data[DOMAIN]
hass.data[DOMAIN].pop(config_entry.entry_id)
return True
async def get_api(hass, host, port, username=None, password=None):
async def get_api(hass, entry):
"""Get Transmission client."""
host = entry[CONF_HOST]
port = entry[CONF_PORT]
username = entry.get(CONF_USERNAME)
password = entry.get(CONF_PASSWORD)
try:
api = await hass.async_add_executor_job(
transmissionrpc.Client, host, port, username, password
)
_LOGGER.debug("Successfully connected to %s", host)
return api
except TransmissionError as error:
@ -112,20 +114,13 @@ async def get_api(hass, host, port, username=None, password=None):
_LOGGER.error("Credentials for Transmission client are not valid")
raise AuthenticationError
if "111: Connection refused" in str(error):
_LOGGER.error("Connecting to the Transmission client failed")
_LOGGER.error("Connecting to the Transmission client %s failed", host)
raise CannotConnect
_LOGGER.error(error)
raise UnknownError
async def async_populate_options(hass, config_entry):
"""Populate default options for Transmission Client."""
options = {CONF_SCAN_INTERVAL: config_entry.data["options"][CONF_SCAN_INTERVAL]}
hass.config_entries.async_update_entry(config_entry, options=options)
class TransmissionClient:
"""Transmission Client Object."""
@ -133,33 +128,35 @@ class TransmissionClient:
"""Initialize the Transmission RPC API."""
self.hass = hass
self.config_entry = config_entry
self.scan_interval = self.config_entry.options[CONF_SCAN_INTERVAL]
self.tm_data = None
self._tm_data = None
self.unsub_timer = None
@property
def service_name(self):
"""Return the service name."""
return slugify(f"{SERVICE_ADD_TORRENT}_{self.config_entry.data[CONF_NAME]}")
@property
def api(self):
"""Return the tm_data object."""
return self._tm_data
async def async_setup(self):
"""Set up the Transmission client."""
config = {
CONF_HOST: self.config_entry.data[CONF_HOST],
CONF_PORT: self.config_entry.data[CONF_PORT],
CONF_USERNAME: self.config_entry.data.get(CONF_USERNAME),
CONF_PASSWORD: self.config_entry.data.get(CONF_PASSWORD),
}
try:
api = await get_api(self.hass, **config)
api = await get_api(self.hass, self.config_entry.data)
except CannotConnect:
raise ConfigEntryNotReady
except (AuthenticationError, UnknownError):
return False
self.tm_data = self.hass.data[DOMAIN][DATA_TRANSMISSION] = TransmissionData(
self.hass, self.config_entry, api
)
self._tm_data = TransmissionData(self.hass, self.config_entry, api)
await self.hass.async_add_executor_job(self.tm_data.init_torrent_list)
await self.hass.async_add_executor_job(self.tm_data.update)
self.set_scan_interval(self.scan_interval)
await self.hass.async_add_executor_job(self._tm_data.init_torrent_list)
await self.hass.async_add_executor_job(self._tm_data.update)
self.add_options()
self.set_scan_interval(self.config_entry.options[CONF_SCAN_INTERVAL])
for platform in ["sensor", "switch"]:
self.hass.async_create_task(
@ -181,19 +178,31 @@ class TransmissionClient:
)
self.hass.services.async_register(
DOMAIN, SERVICE_ADD_TORRENT, add_torrent, schema=SERVICE_ADD_TORRENT_SCHEMA
DOMAIN, self.service_name, add_torrent, schema=SERVICE_ADD_TORRENT_SCHEMA
)
self.config_entry.add_update_listener(self.async_options_updated)
return True
def add_options(self):
"""Add options for entry."""
if not self.config_entry.options:
scan_interval = self.config_entry.data.pop(
CONF_SCAN_INTERVAL, DEFAULT_SCAN_INTERVAL
)
options = {CONF_SCAN_INTERVAL: scan_interval}
self.hass.config_entries.async_update_entry(
self.config_entry, options=options
)
def set_scan_interval(self, scan_interval):
"""Update scan interval."""
def refresh(event_time):
async def refresh(event_time):
"""Get the latest data from Transmission."""
self.tm_data.update()
self._tm_data.update()
if self.unsub_timer is not None:
self.unsub_timer()
@ -215,6 +224,7 @@ class TransmissionData:
def __init__(self, hass, config, api):
"""Initialize the Transmission RPC API."""
self.hass = hass
self.config = config
self.data = None
self.torrents = None
self.session = None
@ -223,6 +233,16 @@ class TransmissionData:
self.completed_torrents = []
self.started_torrents = []
@property
def host(self):
"""Return the host name."""
return self.config.data[CONF_HOST]
@property
def signal_options_update(self):
"""Option update signal per transmission entry."""
return f"tm-options-{self.host}"
def update(self):
"""Get the latest data from Transmission instance."""
try:
@ -232,14 +252,13 @@ class TransmissionData:
self.check_completed_torrent()
self.check_started_torrent()
_LOGGER.debug("Torrent Data Updated")
_LOGGER.debug("Torrent Data for %s Updated", self.host)
self.available = True
except TransmissionError:
self.available = False
_LOGGER.error("Unable to connect to Transmission client")
dispatcher_send(self.hass, DATA_UPDATED)
_LOGGER.error("Unable to connect to Transmission client %s", self.host)
dispatcher_send(self.hass, self.signal_options_update)
def init_torrent_list(self):
"""Initialize torrent lists."""

View File

@ -29,32 +29,32 @@ class TransmissionFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
"""Get the options flow for this handler."""
return TransmissionOptionsFlowHandler(config_entry)
def __init__(self):
"""Initialize the Transmission flow."""
self.config = {}
self.errors = {}
async def async_step_user(self, user_input=None):
"""Handle a flow initialized by the user."""
if self.hass.config_entries.async_entries(DOMAIN):
return self.async_abort(reason="one_instance_allowed")
errors = {}
if user_input is not None:
self.config[CONF_NAME] = user_input.pop(CONF_NAME)
for entry in self.hass.config_entries.async_entries(DOMAIN):
if entry.data[CONF_HOST] == user_input[CONF_HOST]:
return self.async_abort(reason="already_configured")
if entry.data[CONF_NAME] == user_input[CONF_NAME]:
errors[CONF_NAME] = "name_exists"
break
try:
await get_api(self.hass, **user_input)
self.config.update(user_input)
if "options" not in self.config:
self.config["options"] = {CONF_SCAN_INTERVAL: DEFAULT_SCAN_INTERVAL}
return self.async_create_entry(
title=self.config[CONF_NAME], data=self.config
)
await get_api(self.hass, user_input)
except AuthenticationError:
self.errors[CONF_USERNAME] = "wrong_credentials"
self.errors[CONF_PASSWORD] = "wrong_credentials"
errors[CONF_USERNAME] = "wrong_credentials"
errors[CONF_PASSWORD] = "wrong_credentials"
except (CannotConnect, UnknownError):
self.errors["base"] = "cannot_connect"
errors["base"] = "cannot_connect"
if not errors:
return self.async_create_entry(
title=user_input[CONF_NAME], data=user_input
)
return self.async_show_form(
step_id="user",
@ -67,15 +67,12 @@ class TransmissionFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
vol.Required(CONF_PORT, default=DEFAULT_PORT): int,
}
),
errors=self.errors,
errors=errors,
)
async def async_step_import(self, import_config):
"""Import from Transmission client config."""
self.config["options"] = {
CONF_SCAN_INTERVAL: import_config.pop(CONF_SCAN_INTERVAL).seconds
}
import_config[CONF_SCAN_INTERVAL] = import_config[CONF_SCAN_INTERVAL].seconds
return await self.async_step_user(user_input=import_config)
@ -95,8 +92,7 @@ class TransmissionOptionsFlowHandler(config_entries.OptionsFlow):
vol.Optional(
CONF_SCAN_INTERVAL,
default=self.config_entry.options.get(
CONF_SCAN_INTERVAL,
self.config_entry.data["options"][CONF_SCAN_INTERVAL],
CONF_SCAN_INTERVAL, DEFAULT_SCAN_INTERVAL
),
): int
}

View File

@ -21,4 +21,3 @@ ATTR_TORRENT = "torrent"
SERVICE_ADD_TORRENT = "add_torrent"
DATA_UPDATED = "transmission_data_updated"
DATA_TRANSMISSION = "data_transmission"

View File

@ -6,7 +6,7 @@ from homeassistant.core import callback
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.entity import Entity
from .const import DATA_TRANSMISSION, DATA_UPDATED, DOMAIN, SENSOR_TYPES
from .const import DOMAIN, SENSOR_TYPES
_LOGGER = logging.getLogger(__name__)
@ -19,7 +19,7 @@ async def async_setup_platform(hass, config, async_add_entities, discovery_info=
async def async_setup_entry(hass, config_entry, async_add_entities):
"""Set up the Transmission sensors."""
transmission_api = hass.data[DOMAIN][DATA_TRANSMISSION]
tm_client = hass.data[DOMAIN][config_entry.entry_id]
name = config_entry.data[CONF_NAME]
dev = []
@ -27,7 +27,7 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
dev.append(
TransmissionSensor(
sensor_type,
transmission_api,
tm_client,
name,
SENSOR_TYPES[sensor_type][0],
SENSOR_TYPES[sensor_type][1],
@ -41,17 +41,12 @@ class TransmissionSensor(Entity):
"""Representation of a Transmission sensor."""
def __init__(
self,
sensor_type,
transmission_api,
client_name,
sensor_name,
unit_of_measurement,
self, sensor_type, tm_client, client_name, sensor_name, unit_of_measurement
):
"""Initialize the sensor."""
self._name = sensor_name
self._state = None
self._transmission_api = transmission_api
self._tm_client = tm_client
self._unit_of_measurement = unit_of_measurement
self._data = None
self.client_name = client_name
@ -62,6 +57,11 @@ class TransmissionSensor(Entity):
"""Return the name of the sensor."""
return f"{self.client_name} {self._name}"
@property
def unique_id(self):
"""Return the unique id of the entity."""
return f"{self._tm_client.api.host}-{self.name}"
@property
def state(self):
"""Return the state of the sensor."""
@ -80,12 +80,14 @@ class TransmissionSensor(Entity):
@property
def available(self):
"""Could the device be accessed during the last update call."""
return self._transmission_api.available
return self._tm_client.api.available
async def async_added_to_hass(self):
"""Handle entity which will be added."""
async_dispatcher_connect(
self.hass, DATA_UPDATED, self._schedule_immediate_update
self.hass,
self._tm_client.api.signal_options_update,
self._schedule_immediate_update,
)
@callback
@ -94,12 +96,12 @@ class TransmissionSensor(Entity):
def update(self):
"""Get the latest data from Transmission and updates the state."""
self._data = self._transmission_api.data
self._data = self._tm_client.api.data
if self.type == "completed_torrents":
self._state = self._transmission_api.get_completed_torrent_count()
self._state = self._tm_client.api.get_completed_torrent_count()
elif self.type == "started_torrents":
self._state = self._transmission_api.get_started_torrent_count()
self._state = self._tm_client.api.get_started_torrent_count()
if self.type == "current_status":
if self._data:

View File

@ -11,30 +11,25 @@
"password": "Password",
"port": "Port"
}
},
"options": {
"title": "Configure Options",
"data": {
"scan_interval": "Update frequency"
}
}
},
"error": {
"name_exists": "Name already exists",
"wrong_credentials": "Wrong username or password",
"cannot_connect": "Unable to Connect to host"
},
"abort": {
"one_instance_allowed": "Only a single instance is necessary."
"already_configured": "Host is already configured."
}
},
"options": {
"step": {
"init": {
"description": "Configure options for Transmission",
"title": "Configure options for Transmission",
"data": {
"scan_interval": "Update frequency"
}
}
}
}
}
}

View File

@ -6,7 +6,7 @@ from homeassistant.core import callback
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.entity import ToggleEntity
from .const import DATA_TRANSMISSION, DATA_UPDATED, DOMAIN, SWITCH_TYPES
from .const import DOMAIN, SWITCH_TYPES
_LOGGING = logging.getLogger(__name__)
@ -19,12 +19,12 @@ async def async_setup_platform(hass, config, async_add_entities, discovery_info=
async def async_setup_entry(hass, config_entry, async_add_entities):
"""Set up the Transmission switch."""
transmission_api = hass.data[DOMAIN][DATA_TRANSMISSION]
tm_client = hass.data[DOMAIN][config_entry.entry_id]
name = config_entry.data[CONF_NAME]
dev = []
for switch_type, switch_name in SWITCH_TYPES.items():
dev.append(TransmissionSwitch(switch_type, switch_name, transmission_api, name))
dev.append(TransmissionSwitch(switch_type, switch_name, tm_client, name))
async_add_entities(dev, True)
@ -32,12 +32,12 @@ async def async_setup_entry(hass, config_entry, async_add_entities):
class TransmissionSwitch(ToggleEntity):
"""Representation of a Transmission switch."""
def __init__(self, switch_type, switch_name, transmission_api, name):
def __init__(self, switch_type, switch_name, tm_client, name):
"""Initialize the Transmission switch."""
self._name = switch_name
self.client_name = name
self.type = switch_type
self._transmission_api = transmission_api
self._tm_client = tm_client
self._state = STATE_OFF
self._data = None
@ -46,6 +46,11 @@ class TransmissionSwitch(ToggleEntity):
"""Return the name of the switch."""
return f"{self.client_name} {self._name}"
@property
def unique_id(self):
"""Return the unique id of the entity."""
return f"{self._tm_client.api.host}-{self.name}"
@property
def state(self):
"""Return the state of the device."""
@ -64,32 +69,34 @@ class TransmissionSwitch(ToggleEntity):
@property
def available(self):
"""Could the device be accessed during the last update call."""
return self._transmission_api.available
return self._tm_client.api.available
def turn_on(self, **kwargs):
"""Turn the device on."""
if self.type == "on_off":
_LOGGING.debug("Starting all torrents")
self._transmission_api.start_torrents()
self._tm_client.api.start_torrents()
elif self.type == "turtle_mode":
_LOGGING.debug("Turning Turtle Mode of Transmission on")
self._transmission_api.set_alt_speed_enabled(True)
self._transmission_api.update()
self._tm_client.api.set_alt_speed_enabled(True)
self._tm_client.api.update()
def turn_off(self, **kwargs):
"""Turn the device off."""
if self.type == "on_off":
_LOGGING.debug("Stoping all torrents")
self._transmission_api.stop_torrents()
self._tm_client.api.stop_torrents()
if self.type == "turtle_mode":
_LOGGING.debug("Turning Turtle Mode of Transmission off")
self._transmission_api.set_alt_speed_enabled(False)
self._transmission_api.update()
self._tm_client.api.set_alt_speed_enabled(False)
self._tm_client.api.update()
async def async_added_to_hass(self):
"""Handle entity which will be added."""
async_dispatcher_connect(
self.hass, DATA_UPDATED, self._schedule_immediate_update
self.hass,
self._tm_client.api.signal_options_update,
self._schedule_immediate_update,
)
@callback
@ -100,12 +107,12 @@ class TransmissionSwitch(ToggleEntity):
"""Get the latest data from Transmission and updates the state."""
active = None
if self.type == "on_off":
self._data = self._transmission_api.data
self._data = self._tm_client.api.data
if self._data:
active = self._data.activeTorrentCount > 0
elif self.type == "turtle_mode":
active = self._transmission_api.get_alt_speed_enabled()
active = self._tm_client.api.get_alt_speed_enabled()
if active is None:
return

View File

@ -1,4 +1,4 @@
"""Tests for Met.no config flow."""
"""Tests for Transmission config flow."""
from datetime import timedelta
from unittest.mock import patch
@ -31,6 +31,14 @@ PASSWORD = "password"
PORT = 9091
SCAN_INTERVAL = 10
MOCK_ENTRY = {
CONF_NAME: NAME,
CONF_HOST: HOST,
CONF_USERNAME: USERNAME,
CONF_PASSWORD: PASSWORD,
CONF_PORT: PORT,
}
@pytest.fixture(name="api")
def mock_transmission_api():
@ -90,18 +98,10 @@ async def test_flow_works(hass, api):
assert result["data"][CONF_NAME] == NAME
assert result["data"][CONF_HOST] == HOST
assert result["data"][CONF_PORT] == PORT
assert result["data"]["options"][CONF_SCAN_INTERVAL] == DEFAULT_SCAN_INTERVAL
# assert result["data"]["options"][CONF_SCAN_INTERVAL] == DEFAULT_SCAN_INTERVAL
# test with all provided
result = await flow.async_step_user(
{
CONF_NAME: NAME,
CONF_HOST: HOST,
CONF_USERNAME: USERNAME,
CONF_PASSWORD: PASSWORD,
CONF_PORT: PORT,
}
)
result = await flow.async_step_user(MOCK_ENTRY)
assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
assert result["title"] == NAME
@ -110,7 +110,7 @@ async def test_flow_works(hass, api):
assert result["data"][CONF_USERNAME] == USERNAME
assert result["data"][CONF_PASSWORD] == PASSWORD
assert result["data"][CONF_PORT] == PORT
assert result["data"]["options"][CONF_SCAN_INTERVAL] == DEFAULT_SCAN_INTERVAL
# assert result["data"]["options"][CONF_SCAN_INTERVAL] == DEFAULT_SCAN_INTERVAL
async def test_options(hass):
@ -118,14 +118,7 @@ async def test_options(hass):
entry = MockConfigEntry(
domain=DOMAIN,
title=CONF_NAME,
data={
"name": DEFAULT_NAME,
"host": HOST,
"username": USERNAME,
"password": PASSWORD,
"port": DEFAULT_PORT,
"options": {CONF_SCAN_INTERVAL: DEFAULT_SCAN_INTERVAL},
},
data=MOCK_ENTRY,
options={CONF_SCAN_INTERVAL: DEFAULT_SCAN_INTERVAL},
)
flow = init_config_flow(hass)
@ -157,7 +150,7 @@ async def test_import(hass, api):
assert result["data"][CONF_NAME] == DEFAULT_NAME
assert result["data"][CONF_HOST] == HOST
assert result["data"][CONF_PORT] == DEFAULT_PORT
assert result["data"]["options"][CONF_SCAN_INTERVAL] == DEFAULT_SCAN_INTERVAL
assert result["data"][CONF_SCAN_INTERVAL] == DEFAULT_SCAN_INTERVAL
# import with all
result = await flow.async_step_import(
@ -177,18 +170,40 @@ async def test_import(hass, api):
assert result["data"][CONF_USERNAME] == USERNAME
assert result["data"][CONF_PASSWORD] == PASSWORD
assert result["data"][CONF_PORT] == PORT
assert result["data"]["options"][CONF_SCAN_INTERVAL] == SCAN_INTERVAL
assert result["data"][CONF_SCAN_INTERVAL] == SCAN_INTERVAL
async def test_integration_already_exists(hass, api):
"""Test we only allow a single config flow."""
MockConfigEntry(domain=DOMAIN).add_to_hass(hass)
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": "user"}
async def test_host_already_configured(hass, api):
"""Test host is already configured."""
entry = MockConfigEntry(
domain=DOMAIN,
data=MOCK_ENTRY,
options={CONF_SCAN_INTERVAL: DEFAULT_SCAN_INTERVAL},
)
entry.add_to_hass(hass)
flow = init_config_flow(hass)
result = await flow.async_step_user(MOCK_ENTRY)
assert result["type"] == "abort"
assert result["reason"] == "one_instance_allowed"
assert result["reason"] == "already_configured"
async def test_name_already_configured(hass, api):
"""Test name is already configured."""
entry = MockConfigEntry(
domain=DOMAIN,
data=MOCK_ENTRY,
options={CONF_SCAN_INTERVAL: DEFAULT_SCAN_INTERVAL},
)
entry.add_to_hass(hass)
mock_entry = MOCK_ENTRY.copy()
mock_entry[CONF_HOST] = "0.0.0.0"
flow = init_config_flow(hass)
result = await flow.async_step_user(mock_entry)
assert result["type"] == "form"
assert result["errors"] == {CONF_NAME: "name_exists"}
async def test_error_on_wrong_credentials(hass, auth_error):

View File

@ -0,0 +1,123 @@
"""Tests for Transmission init."""
from unittest.mock import patch
import pytest
from transmissionrpc.error import TransmissionError
from homeassistant.components import transmission
from homeassistant.exceptions import ConfigEntryNotReady
from homeassistant.setup import async_setup_component
from tests.common import MockConfigEntry, mock_coro
MOCK_ENTRY = MockConfigEntry(
domain=transmission.DOMAIN,
data={
transmission.CONF_NAME: "Transmission",
transmission.CONF_HOST: "0.0.0.0",
transmission.CONF_USERNAME: "user",
transmission.CONF_PASSWORD: "pass",
transmission.CONF_PORT: 9091,
},
)
@pytest.fixture(name="api")
def mock_transmission_api():
"""Mock an api."""
with patch("transmissionrpc.Client"):
yield
@pytest.fixture(name="auth_error")
def mock_api_authentication_error():
"""Mock an api."""
with patch(
"transmissionrpc.Client", side_effect=TransmissionError("401: Unauthorized")
):
yield
@pytest.fixture(name="unknown_error")
def mock_api_unknown_error():
"""Mock an api."""
with patch("transmissionrpc.Client", side_effect=TransmissionError):
yield
async def test_setup_with_no_config(hass):
"""Test that we do not discover anything or try to set up a Transmission client."""
assert await async_setup_component(hass, transmission.DOMAIN, {}) is True
assert transmission.DOMAIN not in hass.data
async def test_setup_with_config(hass, api):
"""Test that we import the config and setup the client."""
config = {
transmission.DOMAIN: {
transmission.CONF_NAME: "Transmission",
transmission.CONF_HOST: "0.0.0.0",
transmission.CONF_USERNAME: "user",
transmission.CONF_PASSWORD: "pass",
transmission.CONF_PORT: 9091,
},
transmission.DOMAIN: {
transmission.CONF_NAME: "Transmission2",
transmission.CONF_HOST: "0.0.0.1",
transmission.CONF_USERNAME: "user",
transmission.CONF_PASSWORD: "pass",
transmission.CONF_PORT: 9091,
},
}
assert await async_setup_component(hass, transmission.DOMAIN, config) is True
async def test_successful_config_entry(hass, api):
"""Test that configured transmission is configured successfully."""
entry = MOCK_ENTRY
entry.add_to_hass(hass)
assert await transmission.async_setup_entry(hass, entry) is True
assert entry.options == {
transmission.CONF_SCAN_INTERVAL: transmission.DEFAULT_SCAN_INTERVAL
}
async def test_setup_failed(hass):
"""Test transmission failed due to an error."""
entry = MOCK_ENTRY
entry.add_to_hass(hass)
# test connection error raising ConfigEntryNotReady
with patch(
"transmissionrpc.Client",
side_effect=TransmissionError("111: Connection refused"),
), pytest.raises(ConfigEntryNotReady):
await transmission.async_setup_entry(hass, entry)
# test Authentication error returning false
with patch(
"transmissionrpc.Client", side_effect=TransmissionError("401: Unauthorized")
):
assert await transmission.async_setup_entry(hass, entry) is False
async def test_unload_entry(hass, api):
"""Test removing transmission client."""
entry = MOCK_ENTRY
entry.add_to_hass(hass)
with patch.object(
hass.config_entries, "async_forward_entry_unload", return_value=mock_coro(True)
) as unload_entry:
assert await transmission.async_setup_entry(hass, entry)
assert await transmission.async_unload_entry(hass, entry)
assert unload_entry.call_count == 2
assert entry.entry_id not in hass.data[transmission.DOMAIN]