Sabnzbd config flow improvments (#70981)

Co-authored-by: epenet <6771947+epenet@users.noreply.github.com>
Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
Shai Ungar 2022-04-28 23:25:17 +03:00 committed by GitHub
parent a9ca774e7e
commit 8883f5482b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 181 additions and 88 deletions

View File

@ -1,24 +1,39 @@
"""Support for monitoring an SABnzbd NZB client.""" """Support for monitoring an SABnzbd NZB client."""
from collections.abc import Callable
import logging import logging
from pysabnzbd import SabnzbdApiException from pysabnzbd import SabnzbdApiException
import voluptuous as vol import voluptuous as vol
from homeassistant.config_entries import SOURCE_IMPORT, ConfigEntry from homeassistant.config_entries import SOURCE_IMPORT, ConfigEntry, ConfigEntryState
from homeassistant.const import CONF_API_KEY, CONF_NAME, CONF_PATH, CONF_URL from homeassistant.const import (
from homeassistant.core import HomeAssistant, ServiceCall CONF_API_KEY,
from homeassistant.exceptions import ConfigEntryNotReady CONF_HOST,
CONF_NAME,
CONF_PATH,
CONF_PORT,
CONF_SENSORS,
CONF_SSL,
CONF_URL,
)
from homeassistant.core import HomeAssistant, ServiceCall, callback
from homeassistant.exceptions import ConfigEntryNotReady, HomeAssistantError
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.event import async_track_time_interval from homeassistant.helpers.event import async_track_time_interval
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from .const import ( from .const import (
ATTR_API_KEY,
ATTR_SPEED, ATTR_SPEED,
DEFAULT_HOST,
DEFAULT_NAME, DEFAULT_NAME,
DEFAULT_PORT,
DEFAULT_SPEED_LIMIT, DEFAULT_SPEED_LIMIT,
DEFAULT_SSL,
DOMAIN, DOMAIN,
KEY_API, KEY_API,
KEY_API_DATA,
KEY_NAME, KEY_NAME,
SERVICE_PAUSE, SERVICE_PAUSE,
SERVICE_RESUME, SERVICE_RESUME,
@ -27,23 +42,50 @@ from .const import (
UPDATE_INTERVAL, UPDATE_INTERVAL,
) )
from .sab import get_client from .sab import get_client
from .sensor import SENSOR_KEYS
PLATFORMS = ["sensor"] PLATFORMS = ["sensor"]
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
SPEED_LIMIT_SCHEMA = vol.Schema( SERVICES = (
{vol.Optional(ATTR_SPEED, default=DEFAULT_SPEED_LIMIT): cv.string} SERVICE_PAUSE,
SERVICE_RESUME,
SERVICE_SET_SPEED,
)
SERVICE_BASE_SCHEMA = vol.Schema(
{
vol.Required(ATTR_API_KEY): cv.string,
}
)
SERVICE_SPEED_SCHEMA = SERVICE_BASE_SCHEMA.extend(
{
vol.Optional(ATTR_SPEED, default=DEFAULT_SPEED_LIMIT): cv.string,
}
) )
CONFIG_SCHEMA = vol.Schema( CONFIG_SCHEMA = vol.Schema(
{ {
DOMAIN: vol.Schema( DOMAIN: vol.Schema(
{ vol.All(
vol.Required(CONF_API_KEY): str, cv.deprecated(CONF_HOST),
vol.Optional(CONF_NAME, default=DEFAULT_NAME): str, cv.deprecated(CONF_PORT),
vol.Required(CONF_URL): str, cv.deprecated(CONF_SENSORS),
vol.Optional(CONF_PATH): str, cv.deprecated(CONF_SSL),
} {
vol.Required(CONF_API_KEY): str,
vol.Optional(CONF_NAME, default=DEFAULT_NAME): str,
vol.Required(CONF_URL): str,
vol.Optional(CONF_PATH): str,
vol.Optional(CONF_HOST, default=DEFAULT_HOST): cv.string,
vol.Optional(CONF_PORT, default=DEFAULT_PORT): cv.port,
vol.Optional(CONF_SENSORS): vol.All(
cv.ensure_list, [vol.In(SENSOR_KEYS)]
),
vol.Optional(CONF_SSL, default=DEFAULT_SSL): cv.boolean,
},
)
) )
}, },
extra=vol.ALLOW_EXTRA, extra=vol.ALLOW_EXTRA,
@ -69,42 +111,73 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True return True
@callback
def async_get_entry_id_for_service_call(hass: HomeAssistant, call: ServiceCall) -> str:
"""Get the entry ID related to a service call (by device ID)."""
call_data_api_key = call.data[ATTR_API_KEY]
for entry in hass.config_entries.async_entries(DOMAIN):
if entry.data[ATTR_API_KEY] == call_data_api_key:
return entry.entry_id
raise ValueError(f"No api for API key: {call_data_api_key}")
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up the SabNzbd Component.""" """Set up the SabNzbd Component."""
sab_api = await get_client(hass, entry.data) sab_api = await get_client(hass, entry.data)
if not sab_api: if not sab_api:
raise ConfigEntryNotReady raise ConfigEntryNotReady
sab_api_data = SabnzbdApiData(sab_api)
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = { hass.data.setdefault(DOMAIN, {})[entry.entry_id] = {
KEY_API: sab_api, KEY_API: sab_api,
KEY_API_DATA: sab_api_data,
KEY_NAME: entry.data[CONF_NAME], KEY_NAME: entry.data[CONF_NAME],
} }
hass.config_entries.async_setup_platforms(entry, PLATFORMS) @callback
def extract_api(func: Callable) -> Callable:
"""Define a decorator to get the correct api for a service call."""
sab_api_data = SabnzbdApiData(sab_api) async def wrapper(call: ServiceCall) -> None:
"""Wrap the service function."""
entry_id = async_get_entry_id_for_service_call(hass, call)
api_data = hass.data[DOMAIN][entry_id][KEY_API_DATA]
async def async_service_handler(service: ServiceCall) -> None: try:
"""Handle service calls.""" await func(call, api_data)
if service.service == SERVICE_PAUSE: except Exception as err:
await sab_api_data.async_pause_queue() raise HomeAssistantError(
elif service.service == SERVICE_RESUME: f"Error while executing {func.__name__}: {err}"
await sab_api_data.async_resume_queue() ) from err
elif service.service == SERVICE_SET_SPEED:
speed = service.data.get(ATTR_SPEED)
await sab_api_data.async_set_queue_speed(speed)
hass.services.async_register( return wrapper
DOMAIN, SERVICE_PAUSE, async_service_handler, schema=vol.Schema({})
)
hass.services.async_register( @extract_api
DOMAIN, SERVICE_RESUME, async_service_handler, schema=vol.Schema({}) async def async_pause_queue(call: ServiceCall, api: SabnzbdApiData) -> None:
) await api.async_pause_queue()
hass.services.async_register( @extract_api
DOMAIN, SERVICE_SET_SPEED, async_service_handler, schema=SPEED_LIMIT_SCHEMA async def async_resume_queue(call: ServiceCall, api: SabnzbdApiData) -> None:
) await api.async_resume_queue()
@extract_api
async def async_set_queue_speed(call: ServiceCall, api: SabnzbdApiData) -> None:
speed = call.data.get(ATTR_SPEED)
await api.async_set_queue_speed(speed)
for service, method, schema in (
(SERVICE_PAUSE, async_pause_queue, SERVICE_BASE_SCHEMA),
(SERVICE_RESUME, async_resume_queue, SERVICE_BASE_SCHEMA),
(SERVICE_SET_SPEED, async_set_queue_speed, SERVICE_SPEED_SCHEMA),
):
if hass.services.has_service(DOMAIN, service):
continue
hass.services.async_register(DOMAIN, service, method, schema=schema)
async def async_update_sabnzbd(now): async def async_update_sabnzbd(now):
"""Refresh SABnzbd queue data.""" """Refresh SABnzbd queue data."""
@ -115,10 +188,31 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
_LOGGER.error(err) _LOGGER.error(err)
async_track_time_interval(hass, async_update_sabnzbd, UPDATE_INTERVAL) async_track_time_interval(hass, async_update_sabnzbd, UPDATE_INTERVAL)
hass.config_entries.async_setup_platforms(entry, PLATFORMS)
return True return True
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload a Sabnzbd config entry."""
unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
if unload_ok:
hass.data[DOMAIN].pop(entry.entry_id)
loaded_entries = [
entry
for entry in hass.config_entries.async_entries(DOMAIN)
if entry.state == ConfigEntryState.LOADED
]
if len(loaded_entries) == 1:
# If this is the last loaded instance of Sabnzbd, deregister any services
# defined during integration setup:
for service_name in SERVICES:
hass.services.async_remove(DOMAIN, service_name)
return unload_ok
class SabnzbdApiData: class SabnzbdApiData:
"""Class for storing/refreshing sabnzbd api queue data.""" """Class for storing/refreshing sabnzbd api queue data."""

View File

@ -70,10 +70,8 @@ class SABnzbdConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
async def async_step_import(self, import_data): async def async_step_import(self, import_data):
"""Import sabnzbd config from configuration.yaml.""" """Import sabnzbd config from configuration.yaml."""
import_data[CONF_URL] = ( protocol = "https://" if import_data[CONF_SSL] else "http://"
("https://" if import_data[CONF_SSL] else "http://") import_data[
+ import_data[CONF_HOST] CONF_URL
+ ":" ] = f"{protocol}{import_data[CONF_HOST]}:{import_data[CONF_PORT]}"
+ str(import_data[CONF_PORT])
)
return await self.async_step_user(import_data) return await self.async_step_user(import_data)

View File

@ -5,8 +5,8 @@ DOMAIN = "sabnzbd"
DATA_SABNZBD = "sabnzbd" DATA_SABNZBD = "sabnzbd"
ATTR_SPEED = "speed" ATTR_SPEED = "speed"
BASE_URL_FORMAT = "{}://{}:{}/" ATTR_API_KEY = "api_key"
CONFIG_FILE = "sabnzbd.conf"
DEFAULT_HOST = "localhost" DEFAULT_HOST = "localhost"
DEFAULT_NAME = "SABnzbd" DEFAULT_NAME = "SABnzbd"
DEFAULT_PORT = 8080 DEFAULT_PORT = 8080
@ -22,4 +22,5 @@ SERVICE_SET_SPEED = "set_speed"
SIGNAL_SABNZBD_UPDATED = "sabnzbd_updated" SIGNAL_SABNZBD_UPDATED = "sabnzbd_updated"
KEY_API = "api" KEY_API = "api"
KEY_API_DATA = "api_data"
KEY_NAME = "name" KEY_NAME = "name"

View File

@ -1,10 +0,0 @@
"""Errors for the Sabnzbd component."""
from homeassistant.exceptions import HomeAssistantError
class AuthenticationError(HomeAssistantError):
"""Wrong Username or Password."""
class UnknownError(HomeAssistantError):
"""Unknown Error."""

View File

@ -10,12 +10,12 @@ from homeassistant.components.sensor import (
) )
from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.dispatcher import async_dispatcher_connect
from . import DOMAIN, SIGNAL_SABNZBD_UPDATED, SabnzbdApiData from . import DOMAIN, SIGNAL_SABNZBD_UPDATED
from ...config_entries import ConfigEntry from ...config_entries import ConfigEntry
from ...const import DATA_GIGABYTES, DATA_MEGABYTES, DATA_RATE_MEGABYTES_PER_SECOND from ...const import DATA_GIGABYTES, DATA_MEGABYTES, DATA_RATE_MEGABYTES_PER_SECOND
from ...core import HomeAssistant from ...core import HomeAssistant
from ...helpers.entity_platform import AddEntitiesCallback from ...helpers.entity_platform import AddEntitiesCallback
from .const import KEY_API, KEY_NAME from .const import KEY_API_DATA, KEY_NAME
@dataclass @dataclass
@ -109,9 +109,8 @@ async def async_setup_entry(
) -> None: ) -> None:
"""Set up a Sabnzbd sensor entry.""" """Set up a Sabnzbd sensor entry."""
sab_api = hass.data[DOMAIN][config_entry.entry_id][KEY_API] sab_api_data = hass.data[DOMAIN][config_entry.entry_id][KEY_API_DATA]
client_name = hass.data[DOMAIN][config_entry.entry_id][KEY_NAME] client_name = hass.data[DOMAIN][config_entry.entry_id][KEY_NAME]
sab_api_data = SabnzbdApiData(sab_api)
async_add_entities( async_add_entities(
[SabnzbdSensor(sab_api_data, client_name, sensor) for sensor in SENSOR_TYPES] [SabnzbdSensor(sab_api_data, client_name, sensor) for sensor in SENSOR_TYPES]

View File

@ -1,13 +1,33 @@
pause: pause:
name: Pause name: Pause
description: Pauses downloads. description: Pauses downloads.
fields:
api_key:
name: Sabnzbd API key
description: The Sabnzbd API key to pause downloads
required: true
selector:
text:
resume: resume:
name: Resume name: Resume
description: Resumes downloads. description: Resumes downloads.
fields:
api_key:
name: Sabnzbd API key
description: The Sabnzbd API key to resume downloads
required: true
selector:
text:
set_speed: set_speed:
name: Set speed name: Set speed
description: Sets the download speed limit. description: Sets the download speed limit.
fields: fields:
api_key:
name: Sabnzbd API key
description: The Sabnzbd API key to set speed limit
required: true
selector:
text:
speed: speed:
name: Speed name: Speed
description: Speed limit. If specified as a number with no units, will be interpreted as a percent. If units are provided (e.g., 500K) will be interpreted absolutely. description: Speed limit. If specified as a number with no units, will be interpreted as a percent. If units are provided (e.g., 500K) will be interpreted absolutely.

View File

@ -3,7 +3,7 @@ from unittest.mock import patch
from pysabnzbd import SabnzbdApiException from pysabnzbd import SabnzbdApiException
from homeassistant import data_entry_flow from homeassistant import config_entries, data_entry_flow
from homeassistant.components.sabnzbd import DOMAIN from homeassistant.components.sabnzbd import DOMAIN
from homeassistant.config_entries import SOURCE_IMPORT, SOURCE_USER from homeassistant.config_entries import SOURCE_IMPORT, SOURCE_USER
from homeassistant.const import ( from homeassistant.const import (
@ -15,8 +15,7 @@ from homeassistant.const import (
CONF_SSL, CONF_SSL,
CONF_URL, CONF_URL,
) )
from homeassistant.data_entry_flow import RESULT_TYPE_FORM
from tests.common import MockConfigEntry
VALID_CONFIG = { VALID_CONFIG = {
CONF_NAME: "Sabnzbd", CONF_NAME: "Sabnzbd",
@ -37,21 +36,34 @@ VALID_CONFIG_OLD = {
async def test_create_entry(hass): async def test_create_entry(hass):
"""Test that the user step works.""" """Test that the user step works."""
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
assert result["type"] == RESULT_TYPE_FORM
assert result["errors"] == {}
with patch( with patch(
"homeassistant.components.sabnzbd.sab.SabnzbdApi.check_available", "homeassistant.components.sabnzbd.sab.SabnzbdApi.check_available",
return_value=True, return_value=True,
): ), patch(
result = await hass.config_entries.flow.async_init( "homeassistant.components.sabnzbd.async_setup_entry",
DOMAIN, return_value=True,
context={"source": SOURCE_USER}, ) as mock_setup_entry:
data=VALID_CONFIG, result2 = await hass.config_entries.flow.async_configure(
result["flow_id"],
VALID_CONFIG,
) )
await hass.async_block_till_done()
assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY assert result2["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
assert result["title"] == "edc3eee7330e" assert result2["title"] == "edc3eee7330e"
assert result["data"][CONF_NAME] == "Sabnzbd" assert result2["data"] == {
assert result["data"][CONF_API_KEY] == "edc3eee7330e4fdda04489e3fbc283d0" CONF_API_KEY: "edc3eee7330e4fdda04489e3fbc283d0",
assert result["data"][CONF_PATH] == "" CONF_NAME: "Sabnzbd",
CONF_PATH: "",
CONF_URL: "http://localhost:8080",
}
assert len(mock_setup_entry.mock_calls) == 1
async def test_auth_error(hass): async def test_auth_error(hass):
@ -69,27 +81,6 @@ async def test_auth_error(hass):
assert result["errors"] == {"base": "cannot_connect"} assert result["errors"] == {"base": "cannot_connect"}
async def test_integration_already_exists(hass):
"""Test we only allow a single config flow."""
with patch(
"homeassistant.components.sabnzbd.sab.SabnzbdApi.check_available",
return_value=True,
):
MockConfigEntry(
domain=DOMAIN,
unique_id="123456",
data=VALID_CONFIG,
).add_to_hass(hass)
result = await hass.config_entries.flow.async_init(
DOMAIN,
context={"source": SOURCE_USER},
data=VALID_CONFIG,
)
assert result["type"] == "create_entry"
async def test_import_flow(hass) -> None: async def test_import_flow(hass) -> None:
"""Test the import configuration flow.""" """Test the import configuration flow."""
with patch( with patch(