From 7cb6607b1f6a86469cb79314594a2920141ac224 Mon Sep 17 00:00:00 2001 From: Rami Mosleh Date: Wed, 23 Oct 2019 23:09:11 +0300 Subject: [PATCH] Allow multiple Transmission clients and add unique_id to entities (#28136) * Allow multiple clients + improvements * remove commented code * fixed test_init.py --- .coveragerc | 1 - .../transmission/.translations/en.json | 37 ++--- .../components/transmission/__init__.py | 153 ++++++++++-------- .../components/transmission/config_flow.py | 46 +++--- .../components/transmission/const.py | 1 - .../components/transmission/sensor.py | 32 ++-- .../components/transmission/strings.json | 13 +- .../components/transmission/switch.py | 37 +++-- .../transmission/test_config_flow.py | 73 +++++---- tests/components/transmission/test_init.py | 123 ++++++++++++++ 10 files changed, 333 insertions(+), 183 deletions(-) create mode 100644 tests/components/transmission/test_init.py diff --git a/.coveragerc b/.coveragerc index 3645eb00d33..f97a7524a21 100644 --- a/.coveragerc +++ b/.coveragerc @@ -703,7 +703,6 @@ omit = homeassistant/components/tradfri/base_class.py homeassistant/components/trafikverket_train/sensor.py homeassistant/components/trafikverket_weatherstation/sensor.py - homeassistant/components/transmission/__init__.py homeassistant/components/transmission/sensor.py homeassistant/components/transmission/switch.py homeassistant/components/transmission/const.py diff --git a/homeassistant/components/transmission/.translations/en.json b/homeassistant/components/transmission/.translations/en.json index 67461d1a3e8..45c16be36e2 100644 --- a/homeassistant/components/transmission/.translations/en.json +++ b/homeassistant/components/transmission/.translations/en.json @@ -1,39 +1,34 @@ { "config": { - "abort": { - "one_instance_allowed": "Only a single instance is necessary." - }, - "error": { - "cannot_connect": "Unable to Connect to host", - "wrong_credentials": "Wrong username or password" - }, + "title": "Transmission", "step": { - "options": { - "data": { - "scan_interval": "Update frequency" - }, - "title": "Configure Options" - }, "user": { + "title": "Setup Transmission Client", "data": { - "host": "Host", "name": "Name", + "host": "Host", + "username": "Username", "password": "Password", - "port": "Port", - "username": "Username" - }, - "title": "Setup Transmission Client" + "port": "Port" + } } }, - "title": "Transmission" + "error": { + "name_exists": "Name already exists", + "wrong_credentials": "Wrong username or password", + "cannot_connect": "Unable to Connect to host" + }, + "abort": { + "already_configured": "Host is already configured." + } }, "options": { "step": { "init": { + "title": "Configure options for Transmission", "data": { "scan_interval": "Update frequency" - }, - "description": "Configure options for Transmission" + } } } } diff --git a/homeassistant/components/transmission/__init__.py b/homeassistant/components/transmission/__init__.py index e6ddd87bdf5..6cfd6bf640a 100644 --- a/homeassistant/components/transmission/__init__.py +++ b/homeassistant/components/transmission/__init__.py @@ -19,11 +19,10 @@ from homeassistant.exceptions import ConfigEntryNotReady from homeassistant.helpers import config_validation as cv from homeassistant.helpers.dispatcher import dispatcher_send from homeassistant.helpers.event import async_track_time_interval +from homeassistant.util import slugify from .const import ( ATTR_TORRENT, - DATA_TRANSMISSION, - DATA_UPDATED, DEFAULT_NAME, DEFAULT_PORT, DEFAULT_SCAN_INTERVAL, @@ -37,74 +36,77 @@ _LOGGER = logging.getLogger(__name__) SERVICE_ADD_TORRENT_SCHEMA = vol.Schema({vol.Required(ATTR_TORRENT): cv.string}) +TRANS_SCHEMA = vol.All( + vol.Schema( + { + vol.Required(CONF_HOST): cv.string, + vol.Optional(CONF_PASSWORD): cv.string, + vol.Optional(CONF_USERNAME): cv.string, + vol.Optional(CONF_PORT, default=DEFAULT_PORT): cv.port, + vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string, + vol.Optional( + CONF_SCAN_INTERVAL, default=DEFAULT_SCAN_INTERVAL + ): cv.time_period, + } + ) +) + CONFIG_SCHEMA = vol.Schema( - { - DOMAIN: vol.Schema( - { - vol.Required(CONF_HOST): cv.string, - vol.Optional(CONF_PASSWORD): cv.string, - vol.Optional(CONF_USERNAME): cv.string, - vol.Optional(CONF_PORT, default=DEFAULT_PORT): cv.port, - vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string, - vol.Optional( - CONF_SCAN_INTERVAL, default=DEFAULT_SCAN_INTERVAL - ): cv.time_period, - } - ) - }, - extra=vol.ALLOW_EXTRA, + {DOMAIN: vol.All(cv.ensure_list, [TRANS_SCHEMA])}, extra=vol.ALLOW_EXTRA ) async def async_setup(hass, config): """Import the Transmission Component from config.""" - if not hass.config_entries.async_entries(DOMAIN) and DOMAIN in config: - hass.async_create_task( - hass.config_entries.flow.async_init( - DOMAIN, context={"source": SOURCE_IMPORT}, data=config[DOMAIN] + if DOMAIN in config: + for entry in config[DOMAIN]: + hass.async_create_task( + hass.config_entries.flow.async_init( + DOMAIN, context={"source": SOURCE_IMPORT}, data=entry + ) ) - ) return True async def async_setup_entry(hass, config_entry): """Set up the Transmission Component.""" - if DOMAIN not in hass.data: - hass.data[DOMAIN] = {} - - if not config_entry.options: - await async_populate_options(hass, config_entry) - client = TransmissionClient(hass, config_entry) - client_id = config_entry.entry_id - hass.data[DOMAIN][client_id] = client + hass.data.setdefault(DOMAIN, {})[config_entry.entry_id] = client + if not await client.async_setup(): return False return True -async def async_unload_entry(hass, entry): +async def async_unload_entry(hass, config_entry): """Unload Transmission Entry from config_entry.""" - hass.services.async_remove(DOMAIN, SERVICE_ADD_TORRENT) - if hass.data[DOMAIN][entry.entry_id].unsub_timer: - hass.data[DOMAIN][entry.entry_id].unsub_timer() + client = hass.data[DOMAIN][config_entry.entry_id] + hass.services.async_remove(DOMAIN, client.service_name) + if client.unsub_timer: + client.unsub_timer() for component in "sensor", "switch": - await hass.config_entries.async_forward_entry_unload(entry, component) + await hass.config_entries.async_forward_entry_unload(config_entry, component) - del hass.data[DOMAIN] + hass.data[DOMAIN].pop(config_entry.entry_id) return True -async def get_api(hass, host, port, username=None, password=None): +async def get_api(hass, entry): """Get Transmission client.""" + host = entry[CONF_HOST] + port = entry[CONF_PORT] + username = entry.get(CONF_USERNAME) + password = entry.get(CONF_PASSWORD) + try: api = await hass.async_add_executor_job( transmissionrpc.Client, host, port, username, password ) + _LOGGER.debug("Successfully connected to %s", host) return api except TransmissionError as error: @@ -112,20 +114,13 @@ async def get_api(hass, host, port, username=None, password=None): _LOGGER.error("Credentials for Transmission client are not valid") raise AuthenticationError if "111: Connection refused" in str(error): - _LOGGER.error("Connecting to the Transmission client failed") + _LOGGER.error("Connecting to the Transmission client %s failed", host) raise CannotConnect _LOGGER.error(error) raise UnknownError -async def async_populate_options(hass, config_entry): - """Populate default options for Transmission Client.""" - options = {CONF_SCAN_INTERVAL: config_entry.data["options"][CONF_SCAN_INTERVAL]} - - hass.config_entries.async_update_entry(config_entry, options=options) - - class TransmissionClient: """Transmission Client Object.""" @@ -133,33 +128,35 @@ class TransmissionClient: """Initialize the Transmission RPC API.""" self.hass = hass self.config_entry = config_entry - self.scan_interval = self.config_entry.options[CONF_SCAN_INTERVAL] - self.tm_data = None + self._tm_data = None self.unsub_timer = None + @property + def service_name(self): + """Return the service name.""" + return slugify(f"{SERVICE_ADD_TORRENT}_{self.config_entry.data[CONF_NAME]}") + + @property + def api(self): + """Return the tm_data object.""" + return self._tm_data + async def async_setup(self): """Set up the Transmission client.""" - config = { - CONF_HOST: self.config_entry.data[CONF_HOST], - CONF_PORT: self.config_entry.data[CONF_PORT], - CONF_USERNAME: self.config_entry.data.get(CONF_USERNAME), - CONF_PASSWORD: self.config_entry.data.get(CONF_PASSWORD), - } try: - api = await get_api(self.hass, **config) + api = await get_api(self.hass, self.config_entry.data) except CannotConnect: raise ConfigEntryNotReady except (AuthenticationError, UnknownError): return False - self.tm_data = self.hass.data[DOMAIN][DATA_TRANSMISSION] = TransmissionData( - self.hass, self.config_entry, api - ) + self._tm_data = TransmissionData(self.hass, self.config_entry, api) - await self.hass.async_add_executor_job(self.tm_data.init_torrent_list) - await self.hass.async_add_executor_job(self.tm_data.update) - self.set_scan_interval(self.scan_interval) + await self.hass.async_add_executor_job(self._tm_data.init_torrent_list) + await self.hass.async_add_executor_job(self._tm_data.update) + self.add_options() + self.set_scan_interval(self.config_entry.options[CONF_SCAN_INTERVAL]) for platform in ["sensor", "switch"]: self.hass.async_create_task( @@ -181,19 +178,31 @@ class TransmissionClient: ) self.hass.services.async_register( - DOMAIN, SERVICE_ADD_TORRENT, add_torrent, schema=SERVICE_ADD_TORRENT_SCHEMA + DOMAIN, self.service_name, add_torrent, schema=SERVICE_ADD_TORRENT_SCHEMA ) self.config_entry.add_update_listener(self.async_options_updated) return True + def add_options(self): + """Add options for entry.""" + if not self.config_entry.options: + scan_interval = self.config_entry.data.pop( + CONF_SCAN_INTERVAL, DEFAULT_SCAN_INTERVAL + ) + options = {CONF_SCAN_INTERVAL: scan_interval} + + self.hass.config_entries.async_update_entry( + self.config_entry, options=options + ) + def set_scan_interval(self, scan_interval): """Update scan interval.""" - def refresh(event_time): + async def refresh(event_time): """Get the latest data from Transmission.""" - self.tm_data.update() + self._tm_data.update() if self.unsub_timer is not None: self.unsub_timer() @@ -215,6 +224,7 @@ class TransmissionData: def __init__(self, hass, config, api): """Initialize the Transmission RPC API.""" self.hass = hass + self.config = config self.data = None self.torrents = None self.session = None @@ -223,6 +233,16 @@ class TransmissionData: self.completed_torrents = [] self.started_torrents = [] + @property + def host(self): + """Return the host name.""" + return self.config.data[CONF_HOST] + + @property + def signal_options_update(self): + """Option update signal per transmission entry.""" + return f"tm-options-{self.host}" + def update(self): """Get the latest data from Transmission instance.""" try: @@ -232,14 +252,13 @@ class TransmissionData: self.check_completed_torrent() self.check_started_torrent() - _LOGGER.debug("Torrent Data Updated") + _LOGGER.debug("Torrent Data for %s Updated", self.host) self.available = True except TransmissionError: self.available = False - _LOGGER.error("Unable to connect to Transmission client") - - dispatcher_send(self.hass, DATA_UPDATED) + _LOGGER.error("Unable to connect to Transmission client %s", self.host) + dispatcher_send(self.hass, self.signal_options_update) def init_torrent_list(self): """Initialize torrent lists.""" diff --git a/homeassistant/components/transmission/config_flow.py b/homeassistant/components/transmission/config_flow.py index 99376f4b6e0..d7b9efb15d8 100644 --- a/homeassistant/components/transmission/config_flow.py +++ b/homeassistant/components/transmission/config_flow.py @@ -29,32 +29,32 @@ class TransmissionFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): """Get the options flow for this handler.""" return TransmissionOptionsFlowHandler(config_entry) - def __init__(self): - """Initialize the Transmission flow.""" - self.config = {} - self.errors = {} - async def async_step_user(self, user_input=None): """Handle a flow initialized by the user.""" - if self.hass.config_entries.async_entries(DOMAIN): - return self.async_abort(reason="one_instance_allowed") + errors = {} if user_input is not None: - self.config[CONF_NAME] = user_input.pop(CONF_NAME) + for entry in self.hass.config_entries.async_entries(DOMAIN): + if entry.data[CONF_HOST] == user_input[CONF_HOST]: + return self.async_abort(reason="already_configured") + if entry.data[CONF_NAME] == user_input[CONF_NAME]: + errors[CONF_NAME] = "name_exists" + break + try: - await get_api(self.hass, **user_input) - self.config.update(user_input) - if "options" not in self.config: - self.config["options"] = {CONF_SCAN_INTERVAL: DEFAULT_SCAN_INTERVAL} - return self.async_create_entry( - title=self.config[CONF_NAME], data=self.config - ) + await get_api(self.hass, user_input) + except AuthenticationError: - self.errors[CONF_USERNAME] = "wrong_credentials" - self.errors[CONF_PASSWORD] = "wrong_credentials" + errors[CONF_USERNAME] = "wrong_credentials" + errors[CONF_PASSWORD] = "wrong_credentials" except (CannotConnect, UnknownError): - self.errors["base"] = "cannot_connect" + errors["base"] = "cannot_connect" + + if not errors: + return self.async_create_entry( + title=user_input[CONF_NAME], data=user_input + ) return self.async_show_form( step_id="user", @@ -67,15 +67,12 @@ class TransmissionFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): vol.Required(CONF_PORT, default=DEFAULT_PORT): int, } ), - errors=self.errors, + errors=errors, ) async def async_step_import(self, import_config): """Import from Transmission client config.""" - self.config["options"] = { - CONF_SCAN_INTERVAL: import_config.pop(CONF_SCAN_INTERVAL).seconds - } - + import_config[CONF_SCAN_INTERVAL] = import_config[CONF_SCAN_INTERVAL].seconds return await self.async_step_user(user_input=import_config) @@ -95,8 +92,7 @@ class TransmissionOptionsFlowHandler(config_entries.OptionsFlow): vol.Optional( CONF_SCAN_INTERVAL, default=self.config_entry.options.get( - CONF_SCAN_INTERVAL, - self.config_entry.data["options"][CONF_SCAN_INTERVAL], + CONF_SCAN_INTERVAL, DEFAULT_SCAN_INTERVAL ), ): int } diff --git a/homeassistant/components/transmission/const.py b/homeassistant/components/transmission/const.py index e4a8b1490c2..472bb32a391 100644 --- a/homeassistant/components/transmission/const.py +++ b/homeassistant/components/transmission/const.py @@ -21,4 +21,3 @@ ATTR_TORRENT = "torrent" SERVICE_ADD_TORRENT = "add_torrent" DATA_UPDATED = "transmission_data_updated" -DATA_TRANSMISSION = "data_transmission" diff --git a/homeassistant/components/transmission/sensor.py b/homeassistant/components/transmission/sensor.py index 30dfa4a3cbe..d9fd2b51144 100644 --- a/homeassistant/components/transmission/sensor.py +++ b/homeassistant/components/transmission/sensor.py @@ -6,7 +6,7 @@ from homeassistant.core import callback from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.entity import Entity -from .const import DATA_TRANSMISSION, DATA_UPDATED, DOMAIN, SENSOR_TYPES +from .const import DOMAIN, SENSOR_TYPES _LOGGER = logging.getLogger(__name__) @@ -19,7 +19,7 @@ async def async_setup_platform(hass, config, async_add_entities, discovery_info= async def async_setup_entry(hass, config_entry, async_add_entities): """Set up the Transmission sensors.""" - transmission_api = hass.data[DOMAIN][DATA_TRANSMISSION] + tm_client = hass.data[DOMAIN][config_entry.entry_id] name = config_entry.data[CONF_NAME] dev = [] @@ -27,7 +27,7 @@ async def async_setup_entry(hass, config_entry, async_add_entities): dev.append( TransmissionSensor( sensor_type, - transmission_api, + tm_client, name, SENSOR_TYPES[sensor_type][0], SENSOR_TYPES[sensor_type][1], @@ -41,17 +41,12 @@ class TransmissionSensor(Entity): """Representation of a Transmission sensor.""" def __init__( - self, - sensor_type, - transmission_api, - client_name, - sensor_name, - unit_of_measurement, + self, sensor_type, tm_client, client_name, sensor_name, unit_of_measurement ): """Initialize the sensor.""" self._name = sensor_name self._state = None - self._transmission_api = transmission_api + self._tm_client = tm_client self._unit_of_measurement = unit_of_measurement self._data = None self.client_name = client_name @@ -62,6 +57,11 @@ class TransmissionSensor(Entity): """Return the name of the sensor.""" return f"{self.client_name} {self._name}" + @property + def unique_id(self): + """Return the unique id of the entity.""" + return f"{self._tm_client.api.host}-{self.name}" + @property def state(self): """Return the state of the sensor.""" @@ -80,12 +80,14 @@ class TransmissionSensor(Entity): @property def available(self): """Could the device be accessed during the last update call.""" - return self._transmission_api.available + return self._tm_client.api.available async def async_added_to_hass(self): """Handle entity which will be added.""" async_dispatcher_connect( - self.hass, DATA_UPDATED, self._schedule_immediate_update + self.hass, + self._tm_client.api.signal_options_update, + self._schedule_immediate_update, ) @callback @@ -94,12 +96,12 @@ class TransmissionSensor(Entity): def update(self): """Get the latest data from Transmission and updates the state.""" - self._data = self._transmission_api.data + self._data = self._tm_client.api.data if self.type == "completed_torrents": - self._state = self._transmission_api.get_completed_torrent_count() + self._state = self._tm_client.api.get_completed_torrent_count() elif self.type == "started_torrents": - self._state = self._transmission_api.get_started_torrent_count() + self._state = self._tm_client.api.get_started_torrent_count() if self.type == "current_status": if self._data: diff --git a/homeassistant/components/transmission/strings.json b/homeassistant/components/transmission/strings.json index 203ed07adb5..45c16be36e2 100644 --- a/homeassistant/components/transmission/strings.json +++ b/homeassistant/components/transmission/strings.json @@ -11,30 +11,25 @@ "password": "Password", "port": "Port" } - }, - "options": { - "title": "Configure Options", - "data": { - "scan_interval": "Update frequency" - } } }, "error": { + "name_exists": "Name already exists", "wrong_credentials": "Wrong username or password", "cannot_connect": "Unable to Connect to host" }, "abort": { - "one_instance_allowed": "Only a single instance is necessary." + "already_configured": "Host is already configured." } }, "options": { "step": { "init": { - "description": "Configure options for Transmission", + "title": "Configure options for Transmission", "data": { "scan_interval": "Update frequency" } } } } -} +} \ No newline at end of file diff --git a/homeassistant/components/transmission/switch.py b/homeassistant/components/transmission/switch.py index 0bb43f715ac..4b93b3f06e2 100644 --- a/homeassistant/components/transmission/switch.py +++ b/homeassistant/components/transmission/switch.py @@ -6,7 +6,7 @@ from homeassistant.core import callback from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.entity import ToggleEntity -from .const import DATA_TRANSMISSION, DATA_UPDATED, DOMAIN, SWITCH_TYPES +from .const import DOMAIN, SWITCH_TYPES _LOGGING = logging.getLogger(__name__) @@ -19,12 +19,12 @@ async def async_setup_platform(hass, config, async_add_entities, discovery_info= async def async_setup_entry(hass, config_entry, async_add_entities): """Set up the Transmission switch.""" - transmission_api = hass.data[DOMAIN][DATA_TRANSMISSION] + tm_client = hass.data[DOMAIN][config_entry.entry_id] name = config_entry.data[CONF_NAME] dev = [] for switch_type, switch_name in SWITCH_TYPES.items(): - dev.append(TransmissionSwitch(switch_type, switch_name, transmission_api, name)) + dev.append(TransmissionSwitch(switch_type, switch_name, tm_client, name)) async_add_entities(dev, True) @@ -32,12 +32,12 @@ async def async_setup_entry(hass, config_entry, async_add_entities): class TransmissionSwitch(ToggleEntity): """Representation of a Transmission switch.""" - def __init__(self, switch_type, switch_name, transmission_api, name): + def __init__(self, switch_type, switch_name, tm_client, name): """Initialize the Transmission switch.""" self._name = switch_name self.client_name = name self.type = switch_type - self._transmission_api = transmission_api + self._tm_client = tm_client self._state = STATE_OFF self._data = None @@ -46,6 +46,11 @@ class TransmissionSwitch(ToggleEntity): """Return the name of the switch.""" return f"{self.client_name} {self._name}" + @property + def unique_id(self): + """Return the unique id of the entity.""" + return f"{self._tm_client.api.host}-{self.name}" + @property def state(self): """Return the state of the device.""" @@ -64,32 +69,34 @@ class TransmissionSwitch(ToggleEntity): @property def available(self): """Could the device be accessed during the last update call.""" - return self._transmission_api.available + return self._tm_client.api.available def turn_on(self, **kwargs): """Turn the device on.""" if self.type == "on_off": _LOGGING.debug("Starting all torrents") - self._transmission_api.start_torrents() + self._tm_client.api.start_torrents() elif self.type == "turtle_mode": _LOGGING.debug("Turning Turtle Mode of Transmission on") - self._transmission_api.set_alt_speed_enabled(True) - self._transmission_api.update() + self._tm_client.api.set_alt_speed_enabled(True) + self._tm_client.api.update() def turn_off(self, **kwargs): """Turn the device off.""" if self.type == "on_off": _LOGGING.debug("Stoping all torrents") - self._transmission_api.stop_torrents() + self._tm_client.api.stop_torrents() if self.type == "turtle_mode": _LOGGING.debug("Turning Turtle Mode of Transmission off") - self._transmission_api.set_alt_speed_enabled(False) - self._transmission_api.update() + self._tm_client.api.set_alt_speed_enabled(False) + self._tm_client.api.update() async def async_added_to_hass(self): """Handle entity which will be added.""" async_dispatcher_connect( - self.hass, DATA_UPDATED, self._schedule_immediate_update + self.hass, + self._tm_client.api.signal_options_update, + self._schedule_immediate_update, ) @callback @@ -100,12 +107,12 @@ class TransmissionSwitch(ToggleEntity): """Get the latest data from Transmission and updates the state.""" active = None if self.type == "on_off": - self._data = self._transmission_api.data + self._data = self._tm_client.api.data if self._data: active = self._data.activeTorrentCount > 0 elif self.type == "turtle_mode": - active = self._transmission_api.get_alt_speed_enabled() + active = self._tm_client.api.get_alt_speed_enabled() if active is None: return diff --git a/tests/components/transmission/test_config_flow.py b/tests/components/transmission/test_config_flow.py index e79f5c8ac96..28fbed9ff42 100644 --- a/tests/components/transmission/test_config_flow.py +++ b/tests/components/transmission/test_config_flow.py @@ -1,4 +1,4 @@ -"""Tests for Met.no config flow.""" +"""Tests for Transmission config flow.""" from datetime import timedelta from unittest.mock import patch @@ -31,6 +31,14 @@ PASSWORD = "password" PORT = 9091 SCAN_INTERVAL = 10 +MOCK_ENTRY = { + CONF_NAME: NAME, + CONF_HOST: HOST, + CONF_USERNAME: USERNAME, + CONF_PASSWORD: PASSWORD, + CONF_PORT: PORT, +} + @pytest.fixture(name="api") def mock_transmission_api(): @@ -90,18 +98,10 @@ async def test_flow_works(hass, api): assert result["data"][CONF_NAME] == NAME assert result["data"][CONF_HOST] == HOST assert result["data"][CONF_PORT] == PORT - assert result["data"]["options"][CONF_SCAN_INTERVAL] == DEFAULT_SCAN_INTERVAL + # assert result["data"]["options"][CONF_SCAN_INTERVAL] == DEFAULT_SCAN_INTERVAL # test with all provided - result = await flow.async_step_user( - { - CONF_NAME: NAME, - CONF_HOST: HOST, - CONF_USERNAME: USERNAME, - CONF_PASSWORD: PASSWORD, - CONF_PORT: PORT, - } - ) + result = await flow.async_step_user(MOCK_ENTRY) assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY assert result["title"] == NAME @@ -110,7 +110,7 @@ async def test_flow_works(hass, api): assert result["data"][CONF_USERNAME] == USERNAME assert result["data"][CONF_PASSWORD] == PASSWORD assert result["data"][CONF_PORT] == PORT - assert result["data"]["options"][CONF_SCAN_INTERVAL] == DEFAULT_SCAN_INTERVAL + # assert result["data"]["options"][CONF_SCAN_INTERVAL] == DEFAULT_SCAN_INTERVAL async def test_options(hass): @@ -118,14 +118,7 @@ async def test_options(hass): entry = MockConfigEntry( domain=DOMAIN, title=CONF_NAME, - data={ - "name": DEFAULT_NAME, - "host": HOST, - "username": USERNAME, - "password": PASSWORD, - "port": DEFAULT_PORT, - "options": {CONF_SCAN_INTERVAL: DEFAULT_SCAN_INTERVAL}, - }, + data=MOCK_ENTRY, options={CONF_SCAN_INTERVAL: DEFAULT_SCAN_INTERVAL}, ) flow = init_config_flow(hass) @@ -157,7 +150,7 @@ async def test_import(hass, api): assert result["data"][CONF_NAME] == DEFAULT_NAME assert result["data"][CONF_HOST] == HOST assert result["data"][CONF_PORT] == DEFAULT_PORT - assert result["data"]["options"][CONF_SCAN_INTERVAL] == DEFAULT_SCAN_INTERVAL + assert result["data"][CONF_SCAN_INTERVAL] == DEFAULT_SCAN_INTERVAL # import with all result = await flow.async_step_import( @@ -177,18 +170,40 @@ async def test_import(hass, api): assert result["data"][CONF_USERNAME] == USERNAME assert result["data"][CONF_PASSWORD] == PASSWORD assert result["data"][CONF_PORT] == PORT - assert result["data"]["options"][CONF_SCAN_INTERVAL] == SCAN_INTERVAL + assert result["data"][CONF_SCAN_INTERVAL] == SCAN_INTERVAL -async def test_integration_already_exists(hass, api): - """Test we only allow a single config flow.""" - MockConfigEntry(domain=DOMAIN).add_to_hass(hass) - - result = await hass.config_entries.flow.async_init( - DOMAIN, context={"source": "user"} +async def test_host_already_configured(hass, api): + """Test host is already configured.""" + entry = MockConfigEntry( + domain=DOMAIN, + data=MOCK_ENTRY, + options={CONF_SCAN_INTERVAL: DEFAULT_SCAN_INTERVAL}, ) + entry.add_to_hass(hass) + flow = init_config_flow(hass) + result = await flow.async_step_user(MOCK_ENTRY) + assert result["type"] == "abort" - assert result["reason"] == "one_instance_allowed" + assert result["reason"] == "already_configured" + + +async def test_name_already_configured(hass, api): + """Test name is already configured.""" + entry = MockConfigEntry( + domain=DOMAIN, + data=MOCK_ENTRY, + options={CONF_SCAN_INTERVAL: DEFAULT_SCAN_INTERVAL}, + ) + entry.add_to_hass(hass) + + mock_entry = MOCK_ENTRY.copy() + mock_entry[CONF_HOST] = "0.0.0.0" + flow = init_config_flow(hass) + result = await flow.async_step_user(mock_entry) + + assert result["type"] == "form" + assert result["errors"] == {CONF_NAME: "name_exists"} async def test_error_on_wrong_credentials(hass, auth_error): diff --git a/tests/components/transmission/test_init.py b/tests/components/transmission/test_init.py new file mode 100644 index 00000000000..4baa00de7a7 --- /dev/null +++ b/tests/components/transmission/test_init.py @@ -0,0 +1,123 @@ +"""Tests for Transmission init.""" + +from unittest.mock import patch + +import pytest +from transmissionrpc.error import TransmissionError + +from homeassistant.components import transmission +from homeassistant.exceptions import ConfigEntryNotReady +from homeassistant.setup import async_setup_component + +from tests.common import MockConfigEntry, mock_coro + +MOCK_ENTRY = MockConfigEntry( + domain=transmission.DOMAIN, + data={ + transmission.CONF_NAME: "Transmission", + transmission.CONF_HOST: "0.0.0.0", + transmission.CONF_USERNAME: "user", + transmission.CONF_PASSWORD: "pass", + transmission.CONF_PORT: 9091, + }, +) + + +@pytest.fixture(name="api") +def mock_transmission_api(): + """Mock an api.""" + with patch("transmissionrpc.Client"): + yield + + +@pytest.fixture(name="auth_error") +def mock_api_authentication_error(): + """Mock an api.""" + with patch( + "transmissionrpc.Client", side_effect=TransmissionError("401: Unauthorized") + ): + yield + + +@pytest.fixture(name="unknown_error") +def mock_api_unknown_error(): + """Mock an api.""" + with patch("transmissionrpc.Client", side_effect=TransmissionError): + yield + + +async def test_setup_with_no_config(hass): + """Test that we do not discover anything or try to set up a Transmission client.""" + assert await async_setup_component(hass, transmission.DOMAIN, {}) is True + assert transmission.DOMAIN not in hass.data + + +async def test_setup_with_config(hass, api): + """Test that we import the config and setup the client.""" + config = { + transmission.DOMAIN: { + transmission.CONF_NAME: "Transmission", + transmission.CONF_HOST: "0.0.0.0", + transmission.CONF_USERNAME: "user", + transmission.CONF_PASSWORD: "pass", + transmission.CONF_PORT: 9091, + }, + transmission.DOMAIN: { + transmission.CONF_NAME: "Transmission2", + transmission.CONF_HOST: "0.0.0.1", + transmission.CONF_USERNAME: "user", + transmission.CONF_PASSWORD: "pass", + transmission.CONF_PORT: 9091, + }, + } + assert await async_setup_component(hass, transmission.DOMAIN, config) is True + + +async def test_successful_config_entry(hass, api): + """Test that configured transmission is configured successfully.""" + + entry = MOCK_ENTRY + entry.add_to_hass(hass) + + assert await transmission.async_setup_entry(hass, entry) is True + assert entry.options == { + transmission.CONF_SCAN_INTERVAL: transmission.DEFAULT_SCAN_INTERVAL + } + + +async def test_setup_failed(hass): + """Test transmission failed due to an error.""" + + entry = MOCK_ENTRY + entry.add_to_hass(hass) + + # test connection error raising ConfigEntryNotReady + with patch( + "transmissionrpc.Client", + side_effect=TransmissionError("111: Connection refused"), + ), pytest.raises(ConfigEntryNotReady): + + await transmission.async_setup_entry(hass, entry) + + # test Authentication error returning false + + with patch( + "transmissionrpc.Client", side_effect=TransmissionError("401: Unauthorized") + ): + + assert await transmission.async_setup_entry(hass, entry) is False + + +async def test_unload_entry(hass, api): + """Test removing transmission client.""" + entry = MOCK_ENTRY + entry.add_to_hass(hass) + + with patch.object( + hass.config_entries, "async_forward_entry_unload", return_value=mock_coro(True) + ) as unload_entry: + assert await transmission.async_setup_entry(hass, entry) + + assert await transmission.async_unload_entry(hass, entry) + assert unload_entry.call_count == 2 + assert entry.entry_id not in hass.data[transmission.DOMAIN]