From 8883f5482b8cb225f37d1299f4c08c8d1cc58845 Mon Sep 17 00:00:00 2001 From: Shai Ungar Date: Thu, 28 Apr 2022 23:25:17 +0300 Subject: [PATCH] Sabnzbd config flow improvments (#70981) Co-authored-by: epenet <6771947+epenet@users.noreply.github.com> Co-authored-by: Martin Hjelmare --- homeassistant/components/sabnzbd/__init__.py | 158 ++++++++++++++---- .../components/sabnzbd/config_flow.py | 10 +- homeassistant/components/sabnzbd/const.py | 5 +- homeassistant/components/sabnzbd/errors.py | 10 -- homeassistant/components/sabnzbd/sensor.py | 7 +- .../components/sabnzbd/services.yaml | 20 +++ tests/components/sabnzbd/test_config_flow.py | 59 +++---- 7 files changed, 181 insertions(+), 88 deletions(-) delete mode 100644 homeassistant/components/sabnzbd/errors.py diff --git a/homeassistant/components/sabnzbd/__init__.py b/homeassistant/components/sabnzbd/__init__.py index bbbfbe18bc1..aca50e404a2 100644 --- a/homeassistant/components/sabnzbd/__init__.py +++ b/homeassistant/components/sabnzbd/__init__.py @@ -1,24 +1,39 @@ """Support for monitoring an SABnzbd NZB client.""" +from collections.abc import Callable import logging from pysabnzbd import SabnzbdApiException import voluptuous as vol -from homeassistant.config_entries import SOURCE_IMPORT, ConfigEntry -from homeassistant.const import CONF_API_KEY, CONF_NAME, CONF_PATH, CONF_URL -from homeassistant.core import HomeAssistant, ServiceCall -from homeassistant.exceptions import ConfigEntryNotReady +from homeassistant.config_entries import SOURCE_IMPORT, ConfigEntry, ConfigEntryState +from homeassistant.const import ( + CONF_API_KEY, + CONF_HOST, + CONF_NAME, + CONF_PATH, + CONF_PORT, + CONF_SENSORS, + CONF_SSL, + CONF_URL, +) +from homeassistant.core import HomeAssistant, ServiceCall, callback +from homeassistant.exceptions import ConfigEntryNotReady, HomeAssistantError import homeassistant.helpers.config_validation as cv from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.event import async_track_time_interval from homeassistant.helpers.typing import ConfigType from .const import ( + ATTR_API_KEY, ATTR_SPEED, + DEFAULT_HOST, DEFAULT_NAME, + DEFAULT_PORT, DEFAULT_SPEED_LIMIT, + DEFAULT_SSL, DOMAIN, KEY_API, + KEY_API_DATA, KEY_NAME, SERVICE_PAUSE, SERVICE_RESUME, @@ -27,23 +42,50 @@ from .const import ( UPDATE_INTERVAL, ) from .sab import get_client +from .sensor import SENSOR_KEYS PLATFORMS = ["sensor"] _LOGGER = logging.getLogger(__name__) -SPEED_LIMIT_SCHEMA = vol.Schema( - {vol.Optional(ATTR_SPEED, default=DEFAULT_SPEED_LIMIT): cv.string} +SERVICES = ( + SERVICE_PAUSE, + SERVICE_RESUME, + SERVICE_SET_SPEED, +) + +SERVICE_BASE_SCHEMA = vol.Schema( + { + vol.Required(ATTR_API_KEY): cv.string, + } +) + +SERVICE_SPEED_SCHEMA = SERVICE_BASE_SCHEMA.extend( + { + vol.Optional(ATTR_SPEED, default=DEFAULT_SPEED_LIMIT): cv.string, + } ) CONFIG_SCHEMA = vol.Schema( { DOMAIN: vol.Schema( - { - vol.Required(CONF_API_KEY): str, - vol.Optional(CONF_NAME, default=DEFAULT_NAME): str, - vol.Required(CONF_URL): str, - vol.Optional(CONF_PATH): str, - } + vol.All( + cv.deprecated(CONF_HOST), + cv.deprecated(CONF_PORT), + cv.deprecated(CONF_SENSORS), + cv.deprecated(CONF_SSL), + { + vol.Required(CONF_API_KEY): str, + vol.Optional(CONF_NAME, default=DEFAULT_NAME): str, + vol.Required(CONF_URL): str, + vol.Optional(CONF_PATH): str, + vol.Optional(CONF_HOST, default=DEFAULT_HOST): cv.string, + vol.Optional(CONF_PORT, default=DEFAULT_PORT): cv.port, + vol.Optional(CONF_SENSORS): vol.All( + cv.ensure_list, [vol.In(SENSOR_KEYS)] + ), + vol.Optional(CONF_SSL, default=DEFAULT_SSL): cv.boolean, + }, + ) ) }, extra=vol.ALLOW_EXTRA, @@ -69,42 +111,73 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: return True +@callback +def async_get_entry_id_for_service_call(hass: HomeAssistant, call: ServiceCall) -> str: + """Get the entry ID related to a service call (by device ID).""" + call_data_api_key = call.data[ATTR_API_KEY] + + for entry in hass.config_entries.async_entries(DOMAIN): + if entry.data[ATTR_API_KEY] == call_data_api_key: + return entry.entry_id + + raise ValueError(f"No api for API key: {call_data_api_key}") + + async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Set up the SabNzbd Component.""" sab_api = await get_client(hass, entry.data) if not sab_api: raise ConfigEntryNotReady + sab_api_data = SabnzbdApiData(sab_api) + hass.data.setdefault(DOMAIN, {})[entry.entry_id] = { KEY_API: sab_api, + KEY_API_DATA: sab_api_data, KEY_NAME: entry.data[CONF_NAME], } - hass.config_entries.async_setup_platforms(entry, PLATFORMS) + @callback + def extract_api(func: Callable) -> Callable: + """Define a decorator to get the correct api for a service call.""" - sab_api_data = SabnzbdApiData(sab_api) + async def wrapper(call: ServiceCall) -> None: + """Wrap the service function.""" + entry_id = async_get_entry_id_for_service_call(hass, call) + api_data = hass.data[DOMAIN][entry_id][KEY_API_DATA] - async def async_service_handler(service: ServiceCall) -> None: - """Handle service calls.""" - if service.service == SERVICE_PAUSE: - await sab_api_data.async_pause_queue() - elif service.service == SERVICE_RESUME: - await sab_api_data.async_resume_queue() - elif service.service == SERVICE_SET_SPEED: - speed = service.data.get(ATTR_SPEED) - await sab_api_data.async_set_queue_speed(speed) + try: + await func(call, api_data) + except Exception as err: + raise HomeAssistantError( + f"Error while executing {func.__name__}: {err}" + ) from err - hass.services.async_register( - DOMAIN, SERVICE_PAUSE, async_service_handler, schema=vol.Schema({}) - ) + return wrapper - hass.services.async_register( - DOMAIN, SERVICE_RESUME, async_service_handler, schema=vol.Schema({}) - ) + @extract_api + async def async_pause_queue(call: ServiceCall, api: SabnzbdApiData) -> None: + await api.async_pause_queue() - hass.services.async_register( - DOMAIN, SERVICE_SET_SPEED, async_service_handler, schema=SPEED_LIMIT_SCHEMA - ) + @extract_api + async def async_resume_queue(call: ServiceCall, api: SabnzbdApiData) -> None: + await api.async_resume_queue() + + @extract_api + async def async_set_queue_speed(call: ServiceCall, api: SabnzbdApiData) -> None: + speed = call.data.get(ATTR_SPEED) + await api.async_set_queue_speed(speed) + + for service, method, schema in ( + (SERVICE_PAUSE, async_pause_queue, SERVICE_BASE_SCHEMA), + (SERVICE_RESUME, async_resume_queue, SERVICE_BASE_SCHEMA), + (SERVICE_SET_SPEED, async_set_queue_speed, SERVICE_SPEED_SCHEMA), + ): + + if hass.services.has_service(DOMAIN, service): + continue + + hass.services.async_register(DOMAIN, service, method, schema=schema) async def async_update_sabnzbd(now): """Refresh SABnzbd queue data.""" @@ -115,10 +188,31 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: _LOGGER.error(err) async_track_time_interval(hass, async_update_sabnzbd, UPDATE_INTERVAL) + hass.config_entries.async_setup_platforms(entry, PLATFORMS) return True +async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: + """Unload a Sabnzbd config entry.""" + unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS) + if unload_ok: + hass.data[DOMAIN].pop(entry.entry_id) + + loaded_entries = [ + entry + for entry in hass.config_entries.async_entries(DOMAIN) + if entry.state == ConfigEntryState.LOADED + ] + if len(loaded_entries) == 1: + # If this is the last loaded instance of Sabnzbd, deregister any services + # defined during integration setup: + for service_name in SERVICES: + hass.services.async_remove(DOMAIN, service_name) + + return unload_ok + + class SabnzbdApiData: """Class for storing/refreshing sabnzbd api queue data.""" diff --git a/homeassistant/components/sabnzbd/config_flow.py b/homeassistant/components/sabnzbd/config_flow.py index 914b1febefc..7930363b2ac 100644 --- a/homeassistant/components/sabnzbd/config_flow.py +++ b/homeassistant/components/sabnzbd/config_flow.py @@ -70,10 +70,8 @@ class SABnzbdConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): async def async_step_import(self, import_data): """Import sabnzbd config from configuration.yaml.""" - import_data[CONF_URL] = ( - ("https://" if import_data[CONF_SSL] else "http://") - + import_data[CONF_HOST] - + ":" - + str(import_data[CONF_PORT]) - ) + protocol = "https://" if import_data[CONF_SSL] else "http://" + import_data[ + CONF_URL + ] = f"{protocol}{import_data[CONF_HOST]}:{import_data[CONF_PORT]}" return await self.async_step_user(import_data) diff --git a/homeassistant/components/sabnzbd/const.py b/homeassistant/components/sabnzbd/const.py index 9092b877b1b..8add1f61493 100644 --- a/homeassistant/components/sabnzbd/const.py +++ b/homeassistant/components/sabnzbd/const.py @@ -5,8 +5,8 @@ DOMAIN = "sabnzbd" DATA_SABNZBD = "sabnzbd" ATTR_SPEED = "speed" -BASE_URL_FORMAT = "{}://{}:{}/" -CONFIG_FILE = "sabnzbd.conf" +ATTR_API_KEY = "api_key" + DEFAULT_HOST = "localhost" DEFAULT_NAME = "SABnzbd" DEFAULT_PORT = 8080 @@ -22,4 +22,5 @@ SERVICE_SET_SPEED = "set_speed" SIGNAL_SABNZBD_UPDATED = "sabnzbd_updated" KEY_API = "api" +KEY_API_DATA = "api_data" KEY_NAME = "name" diff --git a/homeassistant/components/sabnzbd/errors.py b/homeassistant/components/sabnzbd/errors.py deleted file mode 100644 index a14a0af4775..00000000000 --- a/homeassistant/components/sabnzbd/errors.py +++ /dev/null @@ -1,10 +0,0 @@ -"""Errors for the Sabnzbd component.""" -from homeassistant.exceptions import HomeAssistantError - - -class AuthenticationError(HomeAssistantError): - """Wrong Username or Password.""" - - -class UnknownError(HomeAssistantError): - """Unknown Error.""" diff --git a/homeassistant/components/sabnzbd/sensor.py b/homeassistant/components/sabnzbd/sensor.py index 293b14a604b..1d661d90848 100644 --- a/homeassistant/components/sabnzbd/sensor.py +++ b/homeassistant/components/sabnzbd/sensor.py @@ -10,12 +10,12 @@ from homeassistant.components.sensor import ( ) from homeassistant.helpers.dispatcher import async_dispatcher_connect -from . import DOMAIN, SIGNAL_SABNZBD_UPDATED, SabnzbdApiData +from . import DOMAIN, SIGNAL_SABNZBD_UPDATED from ...config_entries import ConfigEntry from ...const import DATA_GIGABYTES, DATA_MEGABYTES, DATA_RATE_MEGABYTES_PER_SECOND from ...core import HomeAssistant from ...helpers.entity_platform import AddEntitiesCallback -from .const import KEY_API, KEY_NAME +from .const import KEY_API_DATA, KEY_NAME @dataclass @@ -109,9 +109,8 @@ async def async_setup_entry( ) -> None: """Set up a Sabnzbd sensor entry.""" - sab_api = hass.data[DOMAIN][config_entry.entry_id][KEY_API] + sab_api_data = hass.data[DOMAIN][config_entry.entry_id][KEY_API_DATA] client_name = hass.data[DOMAIN][config_entry.entry_id][KEY_NAME] - sab_api_data = SabnzbdApiData(sab_api) async_add_entities( [SabnzbdSensor(sab_api_data, client_name, sensor) for sensor in SENSOR_TYPES] diff --git a/homeassistant/components/sabnzbd/services.yaml b/homeassistant/components/sabnzbd/services.yaml index 38f68bfe5dd..2221eed169f 100644 --- a/homeassistant/components/sabnzbd/services.yaml +++ b/homeassistant/components/sabnzbd/services.yaml @@ -1,13 +1,33 @@ pause: name: Pause description: Pauses downloads. + fields: + api_key: + name: Sabnzbd API key + description: The Sabnzbd API key to pause downloads + required: true + selector: + text: resume: name: Resume description: Resumes downloads. + fields: + api_key: + name: Sabnzbd API key + description: The Sabnzbd API key to resume downloads + required: true + selector: + text: set_speed: name: Set speed description: Sets the download speed limit. fields: + api_key: + name: Sabnzbd API key + description: The Sabnzbd API key to set speed limit + required: true + selector: + text: speed: name: Speed description: Speed limit. If specified as a number with no units, will be interpreted as a percent. If units are provided (e.g., 500K) will be interpreted absolutely. diff --git a/tests/components/sabnzbd/test_config_flow.py b/tests/components/sabnzbd/test_config_flow.py index 381928457d2..d04c5b18ab1 100644 --- a/tests/components/sabnzbd/test_config_flow.py +++ b/tests/components/sabnzbd/test_config_flow.py @@ -3,7 +3,7 @@ from unittest.mock import patch from pysabnzbd import SabnzbdApiException -from homeassistant import data_entry_flow +from homeassistant import config_entries, data_entry_flow from homeassistant.components.sabnzbd import DOMAIN from homeassistant.config_entries import SOURCE_IMPORT, SOURCE_USER from homeassistant.const import ( @@ -15,8 +15,7 @@ from homeassistant.const import ( CONF_SSL, CONF_URL, ) - -from tests.common import MockConfigEntry +from homeassistant.data_entry_flow import RESULT_TYPE_FORM VALID_CONFIG = { CONF_NAME: "Sabnzbd", @@ -37,21 +36,34 @@ VALID_CONFIG_OLD = { async def test_create_entry(hass): """Test that the user step works.""" + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + assert result["type"] == RESULT_TYPE_FORM + assert result["errors"] == {} + with patch( "homeassistant.components.sabnzbd.sab.SabnzbdApi.check_available", return_value=True, - ): - result = await hass.config_entries.flow.async_init( - DOMAIN, - context={"source": SOURCE_USER}, - data=VALID_CONFIG, + ), patch( + "homeassistant.components.sabnzbd.async_setup_entry", + return_value=True, + ) as mock_setup_entry: + result2 = await hass.config_entries.flow.async_configure( + result["flow_id"], + VALID_CONFIG, ) + await hass.async_block_till_done() - assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY - assert result["title"] == "edc3eee7330e" - assert result["data"][CONF_NAME] == "Sabnzbd" - assert result["data"][CONF_API_KEY] == "edc3eee7330e4fdda04489e3fbc283d0" - assert result["data"][CONF_PATH] == "" + assert result2["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY + assert result2["title"] == "edc3eee7330e" + assert result2["data"] == { + CONF_API_KEY: "edc3eee7330e4fdda04489e3fbc283d0", + CONF_NAME: "Sabnzbd", + CONF_PATH: "", + CONF_URL: "http://localhost:8080", + } + assert len(mock_setup_entry.mock_calls) == 1 async def test_auth_error(hass): @@ -69,27 +81,6 @@ async def test_auth_error(hass): assert result["errors"] == {"base": "cannot_connect"} -async def test_integration_already_exists(hass): - """Test we only allow a single config flow.""" - with patch( - "homeassistant.components.sabnzbd.sab.SabnzbdApi.check_available", - return_value=True, - ): - MockConfigEntry( - domain=DOMAIN, - unique_id="123456", - data=VALID_CONFIG, - ).add_to_hass(hass) - - result = await hass.config_entries.flow.async_init( - DOMAIN, - context={"source": SOURCE_USER}, - data=VALID_CONFIG, - ) - - assert result["type"] == "create_entry" - - async def test_import_flow(hass) -> None: """Test the import configuration flow.""" with patch(