Sabnzbd config flow improvments (#70981)

Co-authored-by: epenet <6771947+epenet@users.noreply.github.com>
Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
Shai Ungar 2022-04-28 23:25:17 +03:00 committed by GitHub
parent a9ca774e7e
commit 8883f5482b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 181 additions and 88 deletions

View File

@ -1,24 +1,39 @@
"""Support for monitoring an SABnzbd NZB client."""
from collections.abc import Callable
import logging
from pysabnzbd import SabnzbdApiException
import voluptuous as vol
from homeassistant.config_entries import SOURCE_IMPORT, ConfigEntry
from homeassistant.const import CONF_API_KEY, CONF_NAME, CONF_PATH, CONF_URL
from homeassistant.core import HomeAssistant, ServiceCall
from homeassistant.exceptions import ConfigEntryNotReady
from homeassistant.config_entries import SOURCE_IMPORT, ConfigEntry, ConfigEntryState
from homeassistant.const import (
CONF_API_KEY,
CONF_HOST,
CONF_NAME,
CONF_PATH,
CONF_PORT,
CONF_SENSORS,
CONF_SSL,
CONF_URL,
)
from homeassistant.core import HomeAssistant, ServiceCall, callback
from homeassistant.exceptions import ConfigEntryNotReady, HomeAssistantError
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.event import async_track_time_interval
from homeassistant.helpers.typing import ConfigType
from .const import (
ATTR_API_KEY,
ATTR_SPEED,
DEFAULT_HOST,
DEFAULT_NAME,
DEFAULT_PORT,
DEFAULT_SPEED_LIMIT,
DEFAULT_SSL,
DOMAIN,
KEY_API,
KEY_API_DATA,
KEY_NAME,
SERVICE_PAUSE,
SERVICE_RESUME,
@ -27,23 +42,50 @@ from .const import (
UPDATE_INTERVAL,
)
from .sab import get_client
from .sensor import SENSOR_KEYS
PLATFORMS = ["sensor"]
_LOGGER = logging.getLogger(__name__)
SPEED_LIMIT_SCHEMA = vol.Schema(
{vol.Optional(ATTR_SPEED, default=DEFAULT_SPEED_LIMIT): cv.string}
SERVICES = (
SERVICE_PAUSE,
SERVICE_RESUME,
SERVICE_SET_SPEED,
)
SERVICE_BASE_SCHEMA = vol.Schema(
{
vol.Required(ATTR_API_KEY): cv.string,
}
)
SERVICE_SPEED_SCHEMA = SERVICE_BASE_SCHEMA.extend(
{
vol.Optional(ATTR_SPEED, default=DEFAULT_SPEED_LIMIT): cv.string,
}
)
CONFIG_SCHEMA = vol.Schema(
{
DOMAIN: vol.Schema(
{
vol.Required(CONF_API_KEY): str,
vol.Optional(CONF_NAME, default=DEFAULT_NAME): str,
vol.Required(CONF_URL): str,
vol.Optional(CONF_PATH): str,
}
vol.All(
cv.deprecated(CONF_HOST),
cv.deprecated(CONF_PORT),
cv.deprecated(CONF_SENSORS),
cv.deprecated(CONF_SSL),
{
vol.Required(CONF_API_KEY): str,
vol.Optional(CONF_NAME, default=DEFAULT_NAME): str,
vol.Required(CONF_URL): str,
vol.Optional(CONF_PATH): str,
vol.Optional(CONF_HOST, default=DEFAULT_HOST): cv.string,
vol.Optional(CONF_PORT, default=DEFAULT_PORT): cv.port,
vol.Optional(CONF_SENSORS): vol.All(
cv.ensure_list, [vol.In(SENSOR_KEYS)]
),
vol.Optional(CONF_SSL, default=DEFAULT_SSL): cv.boolean,
},
)
)
},
extra=vol.ALLOW_EXTRA,
@ -69,42 +111,73 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True
@callback
def async_get_entry_id_for_service_call(hass: HomeAssistant, call: ServiceCall) -> str:
"""Get the entry ID related to a service call (by device ID)."""
call_data_api_key = call.data[ATTR_API_KEY]
for entry in hass.config_entries.async_entries(DOMAIN):
if entry.data[ATTR_API_KEY] == call_data_api_key:
return entry.entry_id
raise ValueError(f"No api for API key: {call_data_api_key}")
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up the SabNzbd Component."""
sab_api = await get_client(hass, entry.data)
if not sab_api:
raise ConfigEntryNotReady
sab_api_data = SabnzbdApiData(sab_api)
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = {
KEY_API: sab_api,
KEY_API_DATA: sab_api_data,
KEY_NAME: entry.data[CONF_NAME],
}
hass.config_entries.async_setup_platforms(entry, PLATFORMS)
@callback
def extract_api(func: Callable) -> Callable:
"""Define a decorator to get the correct api for a service call."""
sab_api_data = SabnzbdApiData(sab_api)
async def wrapper(call: ServiceCall) -> None:
"""Wrap the service function."""
entry_id = async_get_entry_id_for_service_call(hass, call)
api_data = hass.data[DOMAIN][entry_id][KEY_API_DATA]
async def async_service_handler(service: ServiceCall) -> None:
"""Handle service calls."""
if service.service == SERVICE_PAUSE:
await sab_api_data.async_pause_queue()
elif service.service == SERVICE_RESUME:
await sab_api_data.async_resume_queue()
elif service.service == SERVICE_SET_SPEED:
speed = service.data.get(ATTR_SPEED)
await sab_api_data.async_set_queue_speed(speed)
try:
await func(call, api_data)
except Exception as err:
raise HomeAssistantError(
f"Error while executing {func.__name__}: {err}"
) from err
hass.services.async_register(
DOMAIN, SERVICE_PAUSE, async_service_handler, schema=vol.Schema({})
)
return wrapper
hass.services.async_register(
DOMAIN, SERVICE_RESUME, async_service_handler, schema=vol.Schema({})
)
@extract_api
async def async_pause_queue(call: ServiceCall, api: SabnzbdApiData) -> None:
await api.async_pause_queue()
hass.services.async_register(
DOMAIN, SERVICE_SET_SPEED, async_service_handler, schema=SPEED_LIMIT_SCHEMA
)
@extract_api
async def async_resume_queue(call: ServiceCall, api: SabnzbdApiData) -> None:
await api.async_resume_queue()
@extract_api
async def async_set_queue_speed(call: ServiceCall, api: SabnzbdApiData) -> None:
speed = call.data.get(ATTR_SPEED)
await api.async_set_queue_speed(speed)
for service, method, schema in (
(SERVICE_PAUSE, async_pause_queue, SERVICE_BASE_SCHEMA),
(SERVICE_RESUME, async_resume_queue, SERVICE_BASE_SCHEMA),
(SERVICE_SET_SPEED, async_set_queue_speed, SERVICE_SPEED_SCHEMA),
):
if hass.services.has_service(DOMAIN, service):
continue
hass.services.async_register(DOMAIN, service, method, schema=schema)
async def async_update_sabnzbd(now):
"""Refresh SABnzbd queue data."""
@ -115,10 +188,31 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
_LOGGER.error(err)
async_track_time_interval(hass, async_update_sabnzbd, UPDATE_INTERVAL)
hass.config_entries.async_setup_platforms(entry, PLATFORMS)
return True
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload a Sabnzbd config entry."""
unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
if unload_ok:
hass.data[DOMAIN].pop(entry.entry_id)
loaded_entries = [
entry
for entry in hass.config_entries.async_entries(DOMAIN)
if entry.state == ConfigEntryState.LOADED
]
if len(loaded_entries) == 1:
# If this is the last loaded instance of Sabnzbd, deregister any services
# defined during integration setup:
for service_name in SERVICES:
hass.services.async_remove(DOMAIN, service_name)
return unload_ok
class SabnzbdApiData:
"""Class for storing/refreshing sabnzbd api queue data."""

View File

@ -70,10 +70,8 @@ class SABnzbdConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
async def async_step_import(self, import_data):
"""Import sabnzbd config from configuration.yaml."""
import_data[CONF_URL] = (
("https://" if import_data[CONF_SSL] else "http://")
+ import_data[CONF_HOST]
+ ":"
+ str(import_data[CONF_PORT])
)
protocol = "https://" if import_data[CONF_SSL] else "http://"
import_data[
CONF_URL
] = f"{protocol}{import_data[CONF_HOST]}:{import_data[CONF_PORT]}"
return await self.async_step_user(import_data)

View File

@ -5,8 +5,8 @@ DOMAIN = "sabnzbd"
DATA_SABNZBD = "sabnzbd"
ATTR_SPEED = "speed"
BASE_URL_FORMAT = "{}://{}:{}/"
CONFIG_FILE = "sabnzbd.conf"
ATTR_API_KEY = "api_key"
DEFAULT_HOST = "localhost"
DEFAULT_NAME = "SABnzbd"
DEFAULT_PORT = 8080
@ -22,4 +22,5 @@ SERVICE_SET_SPEED = "set_speed"
SIGNAL_SABNZBD_UPDATED = "sabnzbd_updated"
KEY_API = "api"
KEY_API_DATA = "api_data"
KEY_NAME = "name"

View File

@ -1,10 +0,0 @@
"""Errors for the Sabnzbd component."""
from homeassistant.exceptions import HomeAssistantError
class AuthenticationError(HomeAssistantError):
"""Wrong Username or Password."""
class UnknownError(HomeAssistantError):
"""Unknown Error."""

View File

@ -10,12 +10,12 @@ from homeassistant.components.sensor import (
)
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from . import DOMAIN, SIGNAL_SABNZBD_UPDATED, SabnzbdApiData
from . import DOMAIN, SIGNAL_SABNZBD_UPDATED
from ...config_entries import ConfigEntry
from ...const import DATA_GIGABYTES, DATA_MEGABYTES, DATA_RATE_MEGABYTES_PER_SECOND
from ...core import HomeAssistant
from ...helpers.entity_platform import AddEntitiesCallback
from .const import KEY_API, KEY_NAME
from .const import KEY_API_DATA, KEY_NAME
@dataclass
@ -109,9 +109,8 @@ async def async_setup_entry(
) -> None:
"""Set up a Sabnzbd sensor entry."""
sab_api = hass.data[DOMAIN][config_entry.entry_id][KEY_API]
sab_api_data = hass.data[DOMAIN][config_entry.entry_id][KEY_API_DATA]
client_name = hass.data[DOMAIN][config_entry.entry_id][KEY_NAME]
sab_api_data = SabnzbdApiData(sab_api)
async_add_entities(
[SabnzbdSensor(sab_api_data, client_name, sensor) for sensor in SENSOR_TYPES]

View File

@ -1,13 +1,33 @@
pause:
name: Pause
description: Pauses downloads.
fields:
api_key:
name: Sabnzbd API key
description: The Sabnzbd API key to pause downloads
required: true
selector:
text:
resume:
name: Resume
description: Resumes downloads.
fields:
api_key:
name: Sabnzbd API key
description: The Sabnzbd API key to resume downloads
required: true
selector:
text:
set_speed:
name: Set speed
description: Sets the download speed limit.
fields:
api_key:
name: Sabnzbd API key
description: The Sabnzbd API key to set speed limit
required: true
selector:
text:
speed:
name: Speed
description: Speed limit. If specified as a number with no units, will be interpreted as a percent. If units are provided (e.g., 500K) will be interpreted absolutely.

View File

@ -3,7 +3,7 @@ from unittest.mock import patch
from pysabnzbd import SabnzbdApiException
from homeassistant import data_entry_flow
from homeassistant import config_entries, data_entry_flow
from homeassistant.components.sabnzbd import DOMAIN
from homeassistant.config_entries import SOURCE_IMPORT, SOURCE_USER
from homeassistant.const import (
@ -15,8 +15,7 @@ from homeassistant.const import (
CONF_SSL,
CONF_URL,
)
from tests.common import MockConfigEntry
from homeassistant.data_entry_flow import RESULT_TYPE_FORM
VALID_CONFIG = {
CONF_NAME: "Sabnzbd",
@ -37,21 +36,34 @@ VALID_CONFIG_OLD = {
async def test_create_entry(hass):
"""Test that the user step works."""
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
assert result["type"] == RESULT_TYPE_FORM
assert result["errors"] == {}
with patch(
"homeassistant.components.sabnzbd.sab.SabnzbdApi.check_available",
return_value=True,
):
result = await hass.config_entries.flow.async_init(
DOMAIN,
context={"source": SOURCE_USER},
data=VALID_CONFIG,
), patch(
"homeassistant.components.sabnzbd.async_setup_entry",
return_value=True,
) as mock_setup_entry:
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"],
VALID_CONFIG,
)
await hass.async_block_till_done()
assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
assert result["title"] == "edc3eee7330e"
assert result["data"][CONF_NAME] == "Sabnzbd"
assert result["data"][CONF_API_KEY] == "edc3eee7330e4fdda04489e3fbc283d0"
assert result["data"][CONF_PATH] == ""
assert result2["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
assert result2["title"] == "edc3eee7330e"
assert result2["data"] == {
CONF_API_KEY: "edc3eee7330e4fdda04489e3fbc283d0",
CONF_NAME: "Sabnzbd",
CONF_PATH: "",
CONF_URL: "http://localhost:8080",
}
assert len(mock_setup_entry.mock_calls) == 1
async def test_auth_error(hass):
@ -69,27 +81,6 @@ async def test_auth_error(hass):
assert result["errors"] == {"base": "cannot_connect"}
async def test_integration_already_exists(hass):
"""Test we only allow a single config flow."""
with patch(
"homeassistant.components.sabnzbd.sab.SabnzbdApi.check_available",
return_value=True,
):
MockConfigEntry(
domain=DOMAIN,
unique_id="123456",
data=VALID_CONFIG,
).add_to_hass(hass)
result = await hass.config_entries.flow.async_init(
DOMAIN,
context={"source": SOURCE_USER},
data=VALID_CONFIG,
)
assert result["type"] == "create_entry"
async def test_import_flow(hass) -> None:
"""Test the import configuration flow."""
with patch(