Dns update (#1393)

* Improvements to DNS validator to include IPv6 (#1312)

* improvements to DNS validator to include IPv6

* fixed the DNS validators

* updated per suggestions

* Update const.py

* Update dns.py

* Update validate.py

* Update validate.py

* Update dns.py

* Update test_validate.py

* Update validate.py

* Cleanup

* Don't set default DNS server as default

* Remove update local resolver

* Fix lint
This commit is contained in:
Pascal Vizeli 2019-12-05 21:52:55 +01:00 committed by GitHub
parent f5c171e44f
commit fc5d97562f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 164 additions and 102 deletions

View File

@ -90,12 +90,12 @@ from ..const import (
from ..coresys import CoreSys from ..coresys import CoreSys
from ..discovery.validate import valid_discovery_service from ..discovery.validate import valid_discovery_service
from ..validate import ( from ..validate import (
ALSA_DEVICE, alsa_device,
DOCKER_PORTS, DOCKER_PORTS,
DOCKER_PORTS_DESCRIPTION, DOCKER_PORTS_DESCRIPTION,
NETWORK_PORT, network_port,
TOKEN, token,
UUID_MATCH, uuid_match,
) )
_LOGGER: logging.Logger = logging.getLogger(__name__) _LOGGER: logging.Logger = logging.getLogger(__name__)
@ -182,7 +182,7 @@ SCHEMA_ADDON_CONFIG = vol.Schema(
), ),
vol.Optional(ATTR_INGRESS, default=False): vol.Boolean(), vol.Optional(ATTR_INGRESS, default=False): vol.Boolean(),
vol.Optional(ATTR_INGRESS_PORT, default=8099): vol.Any( vol.Optional(ATTR_INGRESS_PORT, default=8099): vol.Any(
NETWORK_PORT, vol.Equal(0) network_port, vol.Equal(0)
), ),
vol.Optional(ATTR_INGRESS_ENTRY): vol.Coerce(str), vol.Optional(ATTR_INGRESS_ENTRY): vol.Coerce(str),
vol.Optional(ATTR_PANEL_ICON, default="mdi:puzzle"): vol.Coerce(str), vol.Optional(ATTR_PANEL_ICON, default="mdi:puzzle"): vol.Coerce(str),
@ -269,8 +269,8 @@ SCHEMA_ADDON_USER = vol.Schema(
{ {
vol.Required(ATTR_VERSION): vol.Coerce(str), vol.Required(ATTR_VERSION): vol.Coerce(str),
vol.Optional(ATTR_IMAGE): vol.Coerce(str), vol.Optional(ATTR_IMAGE): vol.Coerce(str),
vol.Optional(ATTR_UUID, default=lambda: uuid.uuid4().hex): UUID_MATCH, vol.Optional(ATTR_UUID, default=lambda: uuid.uuid4().hex): uuid_match,
vol.Optional(ATTR_ACCESS_TOKEN): TOKEN, vol.Optional(ATTR_ACCESS_TOKEN): token,
vol.Optional(ATTR_INGRESS_TOKEN, default=secrets.token_urlsafe): vol.Coerce( vol.Optional(ATTR_INGRESS_TOKEN, default=secrets.token_urlsafe): vol.Coerce(
str str
), ),
@ -278,8 +278,8 @@ SCHEMA_ADDON_USER = vol.Schema(
vol.Optional(ATTR_AUTO_UPDATE, default=False): vol.Boolean(), vol.Optional(ATTR_AUTO_UPDATE, default=False): vol.Boolean(),
vol.Optional(ATTR_BOOT): vol.In([BOOT_AUTO, BOOT_MANUAL]), vol.Optional(ATTR_BOOT): vol.In([BOOT_AUTO, BOOT_MANUAL]),
vol.Optional(ATTR_NETWORK): DOCKER_PORTS, vol.Optional(ATTR_NETWORK): DOCKER_PORTS,
vol.Optional(ATTR_AUDIO_OUTPUT): ALSA_DEVICE, vol.Optional(ATTR_AUDIO_OUTPUT): alsa_device,
vol.Optional(ATTR_AUDIO_INPUT): ALSA_DEVICE, vol.Optional(ATTR_AUDIO_INPUT): alsa_device,
vol.Optional(ATTR_PROTECTED, default=True): vol.Boolean(), vol.Optional(ATTR_PROTECTED, default=True): vol.Boolean(),
vol.Optional(ATTR_INGRESS_PANEL, default=False): vol.Boolean(), vol.Optional(ATTR_INGRESS_PANEL, default=False): vol.Boolean(),
}, },
@ -386,7 +386,7 @@ def _single_validate(coresys: CoreSys, typ: str, value: Any, key: str):
elif typ.startswith(V_URL): elif typ.startswith(V_URL):
return vol.Url()(value) return vol.Url()(value)
elif typ.startswith(V_PORT): elif typ.startswith(V_PORT):
return NETWORK_PORT(value) return network_port(value)
elif typ.startswith(V_MATCH): elif typ.startswith(V_MATCH):
return vol.Match(match.group("match"))(str(value)) return vol.Match(match.group("match"))(str(value))
elif typ.startswith(V_LIST): elif typ.startswith(V_LIST):

View File

@ -90,7 +90,7 @@ from ..const import (
) )
from ..coresys import CoreSysAttributes from ..coresys import CoreSysAttributes
from ..exceptions import APIError from ..exceptions import APIError
from ..validate import ALSA_DEVICE, DOCKER_PORTS from ..validate import alsa_device, DOCKER_PORTS
from .utils import api_process, api_process_raw, api_validate from .utils import api_process, api_process_raw, api_validate
_LOGGER: logging.Logger = logging.getLogger(__name__) _LOGGER: logging.Logger = logging.getLogger(__name__)
@ -103,8 +103,8 @@ SCHEMA_OPTIONS = vol.Schema(
vol.Optional(ATTR_BOOT): vol.In([BOOT_AUTO, BOOT_MANUAL]), vol.Optional(ATTR_BOOT): vol.In([BOOT_AUTO, BOOT_MANUAL]),
vol.Optional(ATTR_NETWORK): vol.Any(None, DOCKER_PORTS), vol.Optional(ATTR_NETWORK): vol.Any(None, DOCKER_PORTS),
vol.Optional(ATTR_AUTO_UPDATE): vol.Boolean(), vol.Optional(ATTR_AUTO_UPDATE): vol.Boolean(),
vol.Optional(ATTR_AUDIO_OUTPUT): ALSA_DEVICE, vol.Optional(ATTR_AUDIO_OUTPUT): alsa_device,
vol.Optional(ATTR_AUDIO_INPUT): ALSA_DEVICE, vol.Optional(ATTR_AUDIO_INPUT): alsa_device,
vol.Optional(ATTR_INGRESS_PANEL): vol.Boolean(), vol.Optional(ATTR_INGRESS_PANEL): vol.Boolean(),
} }
) )

View File

@ -24,13 +24,13 @@ from ..const import (
) )
from ..coresys import CoreSysAttributes from ..coresys import CoreSysAttributes
from ..exceptions import APIError from ..exceptions import APIError
from ..validate import DNS_SERVER_LIST from ..validate import dns_server_list
from .utils import api_process, api_process_raw, api_validate from .utils import api_process, api_process_raw, api_validate
_LOGGER: logging.Logger = logging.getLogger(__name__) _LOGGER: logging.Logger = logging.getLogger(__name__)
# pylint: disable=no-value-for-parameter # pylint: disable=no-value-for-parameter
SCHEMA_OPTIONS = vol.Schema({vol.Optional(ATTR_SERVERS): DNS_SERVER_LIST}) SCHEMA_OPTIONS = vol.Schema({vol.Optional(ATTR_SERVERS): dns_server_list})
SCHEMA_VERSION = vol.Schema({vol.Optional(ATTR_VERSION): vol.Coerce(str)}) SCHEMA_VERSION = vol.Schema({vol.Optional(ATTR_VERSION): vol.Coerce(str)})

View File

@ -33,7 +33,7 @@ from ..const import (
) )
from ..coresys import CoreSysAttributes from ..coresys import CoreSysAttributes
from ..exceptions import APIError from ..exceptions import APIError
from ..validate import DOCKER_IMAGE, NETWORK_PORT from ..validate import docker_image, network_port
from .utils import api_process, api_process_raw, api_validate from .utils import api_process, api_process_raw, api_validate
_LOGGER: logging.Logger = logging.getLogger(__name__) _LOGGER: logging.Logger = logging.getLogger(__name__)
@ -42,9 +42,9 @@ _LOGGER: logging.Logger = logging.getLogger(__name__)
SCHEMA_OPTIONS = vol.Schema( SCHEMA_OPTIONS = vol.Schema(
{ {
vol.Optional(ATTR_BOOT): vol.Boolean(), vol.Optional(ATTR_BOOT): vol.Boolean(),
vol.Inclusive(ATTR_IMAGE, "custom_hass"): vol.Maybe(DOCKER_IMAGE), vol.Inclusive(ATTR_IMAGE, "custom_hass"): vol.Maybe(docker_image),
vol.Inclusive(ATTR_LAST_VERSION, "custom_hass"): vol.Maybe(vol.Coerce(str)), vol.Inclusive(ATTR_LAST_VERSION, "custom_hass"): vol.Maybe(vol.Coerce(str)),
vol.Optional(ATTR_PORT): NETWORK_PORT, vol.Optional(ATTR_PORT): network_port,
vol.Optional(ATTR_PASSWORD): vol.Maybe(vol.Coerce(str)), vol.Optional(ATTR_PASSWORD): vol.Maybe(vol.Coerce(str)),
vol.Optional(ATTR_SSL): vol.Boolean(), vol.Optional(ATTR_SSL): vol.Boolean(),
vol.Optional(ATTR_WATCHDOG): vol.Boolean(), vol.Optional(ATTR_WATCHDOG): vol.Boolean(),

View File

@ -41,7 +41,7 @@ from ..const import (
from ..coresys import CoreSysAttributes from ..coresys import CoreSysAttributes
from ..exceptions import APIError from ..exceptions import APIError
from ..utils.validate import validate_timezone from ..utils.validate import validate_timezone
from ..validate import CHANNELS, LOG_LEVEL, REPOSITORIES, WAIT_BOOT from ..validate import channels, log_level, repositories, wait_boot
from .utils import api_process, api_process_raw, api_validate from .utils import api_process, api_process_raw, api_validate
_LOGGER: logging.Logger = logging.getLogger(__name__) _LOGGER: logging.Logger = logging.getLogger(__name__)
@ -49,11 +49,11 @@ _LOGGER: logging.Logger = logging.getLogger(__name__)
# pylint: disable=no-value-for-parameter # pylint: disable=no-value-for-parameter
SCHEMA_OPTIONS = vol.Schema( SCHEMA_OPTIONS = vol.Schema(
{ {
vol.Optional(ATTR_CHANNEL): CHANNELS, vol.Optional(ATTR_CHANNEL): channels,
vol.Optional(ATTR_ADDONS_REPOSITORIES): REPOSITORIES, vol.Optional(ATTR_ADDONS_REPOSITORIES): repositories,
vol.Optional(ATTR_TIMEZONE): validate_timezone, vol.Optional(ATTR_TIMEZONE): validate_timezone,
vol.Optional(ATTR_WAIT_BOOT): WAIT_BOOT, vol.Optional(ATTR_WAIT_BOOT): wait_boot,
vol.Optional(ATTR_LOGGING): LOG_LEVEL, vol.Optional(ATTR_LOGGING): log_level,
vol.Optional(ATTR_DEBUG): vol.Boolean(), vol.Optional(ATTR_DEBUG): vol.Boolean(),
vol.Optional(ATTR_DEBUG_BLOCK): vol.Boolean(), vol.Optional(ATTR_DEBUG_BLOCK): vol.Boolean(),
} }

View File

@ -1,11 +1,11 @@
"""Discovery service for AdGuard.""" """Discovery service for AdGuard."""
import voluptuous as vol import voluptuous as vol
from hassio.validate import NETWORK_PORT from hassio.validate import network_port
from ..const import ATTR_HOST, ATTR_PORT from ..const import ATTR_HOST, ATTR_PORT
SCHEMA = vol.Schema( SCHEMA = vol.Schema(
{vol.Required(ATTR_HOST): vol.Coerce(str), vol.Required(ATTR_PORT): NETWORK_PORT} {vol.Required(ATTR_HOST): vol.Coerce(str), vol.Required(ATTR_PORT): network_port}
) )

View File

@ -1,11 +1,11 @@
"""Discovery service for Almond.""" """Discovery service for Almond."""
import voluptuous as vol import voluptuous as vol
from hassio.validate import NETWORK_PORT from hassio.validate import network_port
from ..const import ATTR_HOST, ATTR_PORT from ..const import ATTR_HOST, ATTR_PORT
SCHEMA = vol.Schema( SCHEMA = vol.Schema(
{vol.Required(ATTR_HOST): vol.Coerce(str), vol.Required(ATTR_PORT): NETWORK_PORT} {vol.Required(ATTR_HOST): vol.Coerce(str), vol.Required(ATTR_PORT): network_port}
) )

View File

@ -1,7 +1,7 @@
"""Discovery service for MQTT.""" """Discovery service for MQTT."""
import voluptuous as vol import voluptuous as vol
from hassio.validate import NETWORK_PORT from hassio.validate import network_port
from ..const import ATTR_HOST, ATTR_PORT, ATTR_API_KEY, ATTR_SERIAL from ..const import ATTR_HOST, ATTR_PORT, ATTR_API_KEY, ATTR_SERIAL
@ -9,7 +9,7 @@ from ..const import ATTR_HOST, ATTR_PORT, ATTR_API_KEY, ATTR_SERIAL
SCHEMA = vol.Schema( SCHEMA = vol.Schema(
{ {
vol.Required(ATTR_HOST): vol.Coerce(str), vol.Required(ATTR_HOST): vol.Coerce(str),
vol.Required(ATTR_PORT): NETWORK_PORT, vol.Required(ATTR_PORT): network_port,
vol.Required(ATTR_SERIAL): vol.Coerce(str), vol.Required(ATTR_SERIAL): vol.Coerce(str),
vol.Required(ATTR_API_KEY): vol.Coerce(str), vol.Required(ATTR_API_KEY): vol.Coerce(str),
} }

View File

@ -1,11 +1,11 @@
"""Discovery service for Home Panel.""" """Discovery service for Home Panel."""
import voluptuous as vol import voluptuous as vol
from hassio.validate import NETWORK_PORT from hassio.validate import network_port
from ..const import ATTR_HOST, ATTR_PORT from ..const import ATTR_HOST, ATTR_PORT
SCHEMA = vol.Schema( SCHEMA = vol.Schema(
{vol.Required(ATTR_HOST): vol.Coerce(str), vol.Required(ATTR_PORT): NETWORK_PORT} {vol.Required(ATTR_HOST): vol.Coerce(str), vol.Required(ATTR_PORT): network_port}
) )

View File

@ -1,7 +1,7 @@
"""Discovery service for MQTT.""" """Discovery service for MQTT."""
import voluptuous as vol import voluptuous as vol
from hassio.validate import NETWORK_PORT from hassio.validate import network_port
from ..const import ( from ..const import (
ATTR_HOST, ATTR_HOST,
@ -16,7 +16,7 @@ from ..const import (
SCHEMA = vol.Schema( SCHEMA = vol.Schema(
{ {
vol.Required(ATTR_HOST): vol.Coerce(str), vol.Required(ATTR_HOST): vol.Coerce(str),
vol.Required(ATTR_PORT): NETWORK_PORT, vol.Required(ATTR_PORT): network_port,
vol.Optional(ATTR_USERNAME): vol.Coerce(str), vol.Optional(ATTR_USERNAME): vol.Coerce(str),
vol.Optional(ATTR_PASSWORD): vol.Coerce(str), vol.Optional(ATTR_PASSWORD): vol.Coerce(str),
vol.Optional(ATTR_SSL, default=False): vol.Boolean(), vol.Optional(ATTR_SSL, default=False): vol.Boolean(),

View File

@ -6,7 +6,7 @@ import voluptuous as vol
from ..const import ATTR_ADDON, ATTR_CONFIG, ATTR_DISCOVERY, ATTR_SERVICE, ATTR_UUID from ..const import ATTR_ADDON, ATTR_CONFIG, ATTR_DISCOVERY, ATTR_SERVICE, ATTR_UUID
from ..utils.validate import schema_or from ..utils.validate import schema_or
from ..validate import UUID_MATCH from ..validate import uuid_match
def valid_discovery_service(service): def valid_discovery_service(service):
@ -31,7 +31,7 @@ SCHEMA_DISCOVERY = vol.Schema(
[ [
vol.Schema( vol.Schema(
{ {
vol.Required(ATTR_UUID): UUID_MATCH, vol.Required(ATTR_UUID): uuid_match,
vol.Required(ATTR_ADDON): vol.Coerce(str), vol.Required(ATTR_ADDON): vol.Coerce(str),
vol.Required(ATTR_SERVICE): valid_discovery_service, vol.Required(ATTR_SERVICE): valid_discovery_service,
vol.Required(ATTR_CONFIG): vol.Maybe(dict), vol.Required(ATTR_CONFIG): vol.Maybe(dict),

View File

@ -17,7 +17,7 @@ from .docker.stats import DockerStats
from .exceptions import CoreDNSError, CoreDNSUpdateError, DockerAPIError from .exceptions import CoreDNSError, CoreDNSUpdateError, DockerAPIError
from .misc.forwarder import DNSForward from .misc.forwarder import DNSForward
from .utils.json import JsonConfig from .utils.json import JsonConfig
from .validate import DNS_URL, SCHEMA_DNS_CONFIG from .validate import dns_url, SCHEMA_DNS_CONFIG
_LOGGER: logging.Logger = logging.getLogger(__name__) _LOGGER: logging.Logger = logging.getLogger(__name__)
@ -115,7 +115,6 @@ class CoreDNS(JsonConfig, CoreSysAttributes):
# Start DNS forwarder # Start DNS forwarder
self.sys_create_task(self.forwarder.start(self.sys_docker.network.dns)) self.sys_create_task(self.forwarder.start(self.sys_docker.network.dns))
self._update_local_resolv()
# Reset container configuration # Reset container configuration
if await self.instance.is_running(): if await self.instance.is_running():
@ -218,9 +217,15 @@ class CoreDNS(JsonConfig, CoreSysAttributes):
# Prepare DNS serverlist: Prio 1 Local, Prio 2 Manual, Prio 3 Fallback # Prepare DNS serverlist: Prio 1 Local, Prio 2 Manual, Prio 3 Fallback
local_dns: List[str] = self.sys_host.network.dns_servers or ["dns://127.0.0.11"] local_dns: List[str] = self.sys_host.network.dns_servers or ["dns://127.0.0.11"]
_LOGGER.debug(
"local-dns = %s, config-dns = %s, backup-dns = %s",
local_dns,
self.servers,
DNS_SERVERS,
)
for server in local_dns + self.servers + DNS_SERVERS: for server in local_dns + self.servers + DNS_SERVERS:
try: try:
DNS_URL(server) dns_url(server)
if server not in dns_servers: if server not in dns_servers:
dns_servers.append(server) dns_servers.append(server)
except vol.Invalid: except vol.Invalid:
@ -346,33 +351,3 @@ class CoreDNS(JsonConfig, CoreSysAttributes):
await self.instance.install(self.version) await self.instance.install(self.version)
except DockerAPIError: except DockerAPIError:
_LOGGER.error("Repairing of CoreDNS fails") _LOGGER.error("Repairing of CoreDNS fails")
def _update_local_resolv(self) -> None:
"""Update local resolv file."""
resolv_lines: List[str] = []
nameserver = f"nameserver {self.sys_docker.network.dns!s}"
# Read resolv config
try:
with RESOLV_CONF.open("r") as resolv:
for line in resolv.readlines():
if not line:
continue
resolv_lines.append(line.strip())
except OSError as err:
_LOGGER.warning("Can't read local resolv: %s", err)
return
if nameserver in resolv_lines:
return
_LOGGER.info("Update resolv from Supervisor")
# Write config back to resolv
resolv_lines.append(nameserver)
try:
with RESOLV_CONF.open("w") as resolv:
for line in resolv_lines:
resolv.write(f"{line}\n")
except OSError as err:
_LOGGER.warning("Can't write local resolv: %s", err)
return

View File

@ -54,4 +54,9 @@ class DockerDNS(DockerInterface, CoreSysAttributes):
) )
self._meta = docker_container.attrs self._meta = docker_container.attrs
_LOGGER.info("Start DNS %s with version %s", self.image, self.version) _LOGGER.info(
"Start DNS %s with version %s - %s",
self.image,
self.version,
self.sys_docker.network.dns,
)

View File

@ -4,7 +4,7 @@ from typing import Any, Dict
from hassio.addons.addon import Addon from hassio.addons.addon import Addon
from hassio.exceptions import ServicesError from hassio.exceptions import ServicesError
from hassio.validate import NETWORK_PORT from hassio.validate import network_port
import voluptuous as vol import voluptuous as vol
from ..const import ( from ..const import (
@ -26,7 +26,7 @@ _LOGGER: logging.Logger = logging.getLogger(__name__)
SCHEMA_SERVICE_MQTT = vol.Schema( SCHEMA_SERVICE_MQTT = vol.Schema(
{ {
vol.Required(ATTR_HOST): vol.Coerce(str), vol.Required(ATTR_HOST): vol.Coerce(str),
vol.Required(ATTR_PORT): NETWORK_PORT, vol.Required(ATTR_PORT): network_port,
vol.Optional(ATTR_USERNAME): vol.Coerce(str), vol.Optional(ATTR_USERNAME): vol.Coerce(str),
vol.Optional(ATTR_PASSWORD): vol.Coerce(str), vol.Optional(ATTR_PASSWORD): vol.Coerce(str),
vol.Optional(ATTR_SSL, default=False): vol.Boolean(), vol.Optional(ATTR_SSL, default=False): vol.Boolean(),

View File

@ -31,7 +31,7 @@ from ..const import (
SNAPSHOT_FULL, SNAPSHOT_FULL,
SNAPSHOT_PARTIAL, SNAPSHOT_PARTIAL,
) )
from ..validate import DOCKER_IMAGE, NETWORK_PORT, REPOSITORIES from ..validate import docker_image, network_port, repositories
ALL_FOLDERS = [FOLDER_HOMEASSISTANT, FOLDER_SHARE, FOLDER_ADDONS, FOLDER_SSL] ALL_FOLDERS = [FOLDER_HOMEASSISTANT, FOLDER_SHARE, FOLDER_ADDONS, FOLDER_SSL]
@ -59,11 +59,11 @@ SCHEMA_SNAPSHOT = vol.Schema(
vol.Optional(ATTR_HOMEASSISTANT, default=dict): vol.Schema( vol.Optional(ATTR_HOMEASSISTANT, default=dict): vol.Schema(
{ {
vol.Optional(ATTR_VERSION): vol.Coerce(str), vol.Optional(ATTR_VERSION): vol.Coerce(str),
vol.Inclusive(ATTR_IMAGE, "custom_hass"): DOCKER_IMAGE, vol.Inclusive(ATTR_IMAGE, "custom_hass"): docker_image,
vol.Inclusive(ATTR_LAST_VERSION, "custom_hass"): vol.Coerce(str), vol.Inclusive(ATTR_LAST_VERSION, "custom_hass"): vol.Coerce(str),
vol.Optional(ATTR_BOOT, default=True): vol.Boolean(), vol.Optional(ATTR_BOOT, default=True): vol.Boolean(),
vol.Optional(ATTR_SSL, default=False): vol.Boolean(), vol.Optional(ATTR_SSL, default=False): vol.Boolean(),
vol.Optional(ATTR_PORT, default=8123): NETWORK_PORT, vol.Optional(ATTR_PORT, default=8123): network_port,
vol.Optional(ATTR_PASSWORD): vol.Maybe(vol.Coerce(str)), vol.Optional(ATTR_PASSWORD): vol.Maybe(vol.Coerce(str)),
vol.Optional(ATTR_REFRESH_TOKEN): vol.Maybe(vol.Coerce(str)), vol.Optional(ATTR_REFRESH_TOKEN): vol.Maybe(vol.Coerce(str)),
vol.Optional(ATTR_WATCHDOG, default=True): vol.Boolean(), vol.Optional(ATTR_WATCHDOG, default=True): vol.Boolean(),
@ -90,7 +90,7 @@ SCHEMA_SNAPSHOT = vol.Schema(
], ],
unique_addons, unique_addons,
), ),
vol.Optional(ATTR_REPOSITORIES, default=list): REPOSITORIES, vol.Optional(ATTR_REPOSITORIES, default=list): repositories,
}, },
extra=vol.ALLOW_EXTRA, extra=vol.ALLOW_EXTRA,
) )

View File

@ -1,6 +1,7 @@
"""Validate functions.""" """Validate functions."""
import re import re
import uuid import uuid
import ipaddress
import voluptuous as vol import voluptuous as vol
@ -35,27 +36,41 @@ from .const import (
CHANNEL_BETA, CHANNEL_BETA,
CHANNEL_DEV, CHANNEL_DEV,
CHANNEL_STABLE, CHANNEL_STABLE,
DNS_SERVERS,
) )
from .utils.validate import validate_timezone from .utils.validate import validate_timezone
RE_REPOSITORY = re.compile(r"^(?P<url>[^#]+)(?:#(?P<branch>[\w\-]+))?$") RE_REPOSITORY = re.compile(r"^(?P<url>[^#]+)(?:#(?P<branch>[\w\-]+))?$")
# pylint: disable=no-value-for-parameter # pylint: disable=no-value-for-parameter
NETWORK_PORT = vol.All(vol.Coerce(int), vol.Range(min=1, max=65535)) # pylint: disable=invalid-name
WAIT_BOOT = vol.All(vol.Coerce(int), vol.Range(min=1, max=60)) network_port = vol.All(vol.Coerce(int), vol.Range(min=1, max=65535))
DOCKER_IMAGE = vol.Match(r"^[\w{}]+/[\-\w{}]+$") wait_boot = vol.All(vol.Coerce(int), vol.Range(min=1, max=60))
ALSA_DEVICE = vol.Maybe(vol.Match(r"\d+,\d+")) docker_image = vol.Match(r"^[\w{}]+/[\-\w{}]+$")
CHANNELS = vol.In([CHANNEL_STABLE, CHANNEL_BETA, CHANNEL_DEV]) alsa_device = vol.Maybe(vol.Match(r"\d+,\d+"))
UUID_MATCH = vol.Match(r"^[0-9a-f]{32}$") channels = vol.In([CHANNEL_STABLE, CHANNEL_BETA, CHANNEL_DEV])
SHA256 = vol.Match(r"^[0-9a-f]{64}$") uuid_match = vol.Match(r"^[0-9a-f]{32}$")
TOKEN = vol.Match(r"^[0-9a-f]{32,256}$") sha256 = vol.Match(r"^[0-9a-f]{64}$")
LOG_LEVEL = vol.In(["debug", "info", "warning", "error", "critical"]) token = vol.Match(r"^[0-9a-f]{32,256}$")
DNS_URL = vol.Match(r"^dns://\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$") log_level = vol.In(["debug", "info", "warning", "error", "critical"])
DNS_SERVER_LIST = vol.All([DNS_URL], vol.Length(max=8))
def validate_repository(repository): def dns_url(url: str) -> str:
""" takes a DNS url (str) and validates that it matches the scheme dns://<ip address>."""
if not url.lower().startswith("dns://"):
raise vol.Invalid("Doesn't start with dns://")
address: str = url[6:] # strip the dns:// off
try:
ipaddress.ip_address(address) # matches ipv4 or ipv6 addresses
except ValueError:
raise vol.Invalid("Invalid DNS URL: {}".format(url))
return url
dns_server_list = vol.All(vol.Length(max=8), [dns_url])
def validate_repository(repository: str) -> str:
"""Validate a valid repository.""" """Validate a valid repository."""
data = RE_REPOSITORY.match(repository) data = RE_REPOSITORY.match(repository)
if not data: if not data:
@ -69,13 +84,13 @@ def validate_repository(repository):
# pylint: disable=no-value-for-parameter # pylint: disable=no-value-for-parameter
REPOSITORIES = vol.All([validate_repository], vol.Unique()) repositories = vol.All([validate_repository], vol.Unique())
DOCKER_PORTS = vol.Schema( DOCKER_PORTS = vol.Schema(
{ {
vol.All(vol.Coerce(str), vol.Match(r"^\d+(?:/tcp|/udp)?$")): vol.Maybe( vol.All(vol.Coerce(str), vol.Match(r"^\d+(?:/tcp|/udp)?$")): vol.Maybe(
NETWORK_PORT network_port
) )
} }
) )
@ -88,13 +103,13 @@ DOCKER_PORTS_DESCRIPTION = vol.Schema(
# pylint: disable=no-value-for-parameter # pylint: disable=no-value-for-parameter
SCHEMA_HASS_CONFIG = vol.Schema( SCHEMA_HASS_CONFIG = vol.Schema(
{ {
vol.Optional(ATTR_UUID, default=lambda: uuid.uuid4().hex): UUID_MATCH, vol.Optional(ATTR_UUID, default=lambda: uuid.uuid4().hex): uuid_match,
vol.Optional(ATTR_VERSION): vol.Maybe(vol.Coerce(str)), vol.Optional(ATTR_VERSION): vol.Maybe(vol.Coerce(str)),
vol.Optional(ATTR_ACCESS_TOKEN): TOKEN, vol.Optional(ATTR_ACCESS_TOKEN): token,
vol.Optional(ATTR_BOOT, default=True): vol.Boolean(), vol.Optional(ATTR_BOOT, default=True): vol.Boolean(),
vol.Inclusive(ATTR_IMAGE, "custom_hass"): DOCKER_IMAGE, vol.Inclusive(ATTR_IMAGE, "custom_hass"): docker_image,
vol.Inclusive(ATTR_LAST_VERSION, "custom_hass"): vol.Coerce(str), vol.Inclusive(ATTR_LAST_VERSION, "custom_hass"): vol.Coerce(str),
vol.Optional(ATTR_PORT, default=8123): NETWORK_PORT, vol.Optional(ATTR_PORT, default=8123): network_port,
vol.Optional(ATTR_PASSWORD): vol.Maybe(vol.Coerce(str)), vol.Optional(ATTR_PASSWORD): vol.Maybe(vol.Coerce(str)),
vol.Optional(ATTR_REFRESH_TOKEN): vol.Maybe(vol.Coerce(str)), vol.Optional(ATTR_REFRESH_TOKEN): vol.Maybe(vol.Coerce(str)),
vol.Optional(ATTR_SSL, default=False): vol.Boolean(), vol.Optional(ATTR_SSL, default=False): vol.Boolean(),
@ -109,7 +124,7 @@ SCHEMA_HASS_CONFIG = vol.Schema(
SCHEMA_UPDATER_CONFIG = vol.Schema( SCHEMA_UPDATER_CONFIG = vol.Schema(
{ {
vol.Optional(ATTR_CHANNEL, default=CHANNEL_STABLE): CHANNELS, vol.Optional(ATTR_CHANNEL, default=CHANNEL_STABLE): channels,
vol.Optional(ATTR_HOMEASSISTANT): vol.Coerce(str), vol.Optional(ATTR_HOMEASSISTANT): vol.Coerce(str),
vol.Optional(ATTR_HASSIO): vol.Coerce(str), vol.Optional(ATTR_HASSIO): vol.Coerce(str),
vol.Optional(ATTR_HASSOS): vol.Coerce(str), vol.Optional(ATTR_HASSOS): vol.Coerce(str),
@ -128,9 +143,9 @@ SCHEMA_HASSIO_CONFIG = vol.Schema(
vol.Optional( vol.Optional(
ATTR_ADDONS_CUSTOM_LIST, ATTR_ADDONS_CUSTOM_LIST,
default=["https://github.com/hassio-addons/repository"], default=["https://github.com/hassio-addons/repository"],
): REPOSITORIES, ): repositories,
vol.Optional(ATTR_WAIT_BOOT, default=5): WAIT_BOOT, vol.Optional(ATTR_WAIT_BOOT, default=5): wait_boot,
vol.Optional(ATTR_LOGGING, default="info"): LOG_LEVEL, vol.Optional(ATTR_LOGGING, default="info"): log_level,
vol.Optional(ATTR_DEBUG, default=False): vol.Boolean(), vol.Optional(ATTR_DEBUG, default=False): vol.Boolean(),
vol.Optional(ATTR_DEBUG_BLOCK, default=False): vol.Boolean(), vol.Optional(ATTR_DEBUG_BLOCK, default=False): vol.Boolean(),
}, },
@ -138,16 +153,16 @@ SCHEMA_HASSIO_CONFIG = vol.Schema(
) )
SCHEMA_AUTH_CONFIG = vol.Schema({SHA256: SHA256}) SCHEMA_AUTH_CONFIG = vol.Schema({sha256: sha256})
SCHEMA_INGRESS_CONFIG = vol.Schema( SCHEMA_INGRESS_CONFIG = vol.Schema(
{ {
vol.Required(ATTR_SESSION, default=dict): vol.Schema( vol.Required(ATTR_SESSION, default=dict): vol.Schema(
{TOKEN: vol.Coerce(float)} {token: vol.Coerce(float)}
), ),
vol.Required(ATTR_PORTS, default=dict): vol.Schema( vol.Required(ATTR_PORTS, default=dict): vol.Schema(
{vol.Coerce(str): NETWORK_PORT} {vol.Coerce(str): network_port}
), ),
}, },
extra=vol.REMOVE_EXTRA, extra=vol.REMOVE_EXTRA,
@ -157,7 +172,7 @@ SCHEMA_INGRESS_CONFIG = vol.Schema(
SCHEMA_DNS_CONFIG = vol.Schema( SCHEMA_DNS_CONFIG = vol.Schema(
{ {
vol.Optional(ATTR_VERSION): vol.Maybe(vol.Coerce(str)), vol.Optional(ATTR_VERSION): vol.Maybe(vol.Coerce(str)),
vol.Optional(ATTR_SERVERS, default=DNS_SERVERS): DNS_SERVER_LIST, vol.Optional(ATTR_SERVERS, default=list): dns_server_list,
}, },
extra=vol.REMOVE_EXTRA, extra=vol.REMOVE_EXTRA,
) )

67
tests/test_validate.py Normal file
View File

@ -0,0 +1,67 @@
"""Test validators."""
import hassio.validate
import voluptuous.error
import pytest
GOOD_V4 = [
"dns://10.0.0.1", # random local
"dns://254.254.254.254", # random high numbers
"DNS://1.1.1.1", # cloudflare
"dns://9.9.9.9", # quad-9
]
GOOD_V6 = [
"dns://2606:4700:4700::1111", # cloudflare
"DNS://2606:4700:4700::1001", # cloudflare
]
BAD = ["hello world", "https://foo.bar", "", "dns://example.com"]
async def test_dns_url_v4_good():
""" tests the DNS validator with known-good ipv6 DNS URLs """
for url in GOOD_V4:
assert hassio.validate.dns_url(url)
async def test_dns_url_v6_good():
""" tests the DNS validator with known-good ipv6 DNS URLs """
for url in GOOD_V6:
assert hassio.validate.dns_url(url)
async def test_dns_server_list_v4():
""" test a list with v4 addresses """
assert hassio.validate.dns_server_list(GOOD_V4)
async def test_dns_server_list_v6():
""" test a list with v6 addresses """
assert hassio.validate.dns_server_list(GOOD_V6)
async def test_dns_server_list_combined():
""" test a list with both v4 and v6 addresses """
combined = GOOD_V4 + GOOD_V6
# test the matches
assert hassio.validate.dns_server_list(combined)
# test max_length is OK still
assert hassio.validate.dns_server_list(combined)
# test that it fails when the list is too long
with pytest.raises(voluptuous.error.Invalid):
hassio.validate.dns_server_list(combined + combined + combined + combined)
async def test_dns_server_list_bad():
""" test the bad list """
# test the matches
with pytest.raises(voluptuous.error.Invalid):
assert hassio.validate.dns_server_list(BAD)
async def test_dns_server_list_bad_combined():
""" test the bad list, combined with the good """
combined = GOOD_V4 + GOOD_V6 + BAD
with pytest.raises(voluptuous.error.Invalid):
# bad list
assert hassio.validate.dns_server_list(combined)