Add is_host_valid util (#76589)

This commit is contained in:
Artem Draft 2022-09-11 19:12:04 +03:00 committed by GitHub
parent b0777e6280
commit 29be6d17b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 47 additions and 56 deletions

View File

@ -1,9 +1,6 @@
"""Config flow to configure the Bravia TV integration.""" """Config flow to configure the Bravia TV integration."""
from __future__ import annotations from __future__ import annotations
from contextlib import suppress
import ipaddress
import re
from typing import Any from typing import Any
from aiohttp import CookieJar from aiohttp import CookieJar
@ -17,6 +14,7 @@ from homeassistant.core import callback
from homeassistant.data_entry_flow import FlowResult from homeassistant.data_entry_flow import FlowResult
from homeassistant.helpers.aiohttp_client import async_create_clientsession from homeassistant.helpers.aiohttp_client import async_create_clientsession
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.util.network import is_host_valid
from . import BraviaTVCoordinator from . import BraviaTVCoordinator
from .const import ( from .const import (
@ -30,15 +28,6 @@ from .const import (
) )
def host_valid(host: str) -> bool:
"""Return True if hostname or IP address is valid."""
with suppress(ValueError):
if ipaddress.ip_address(host).version in [4, 6]:
return True
disallowed = re.compile(r"[^a-zA-Z\d\-]")
return all(x and not disallowed.search(x) for x in host.split("."))
class BraviaTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): class BraviaTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
"""Handle a config flow for Bravia TV integration.""" """Handle a config flow for Bravia TV integration."""
@ -82,7 +71,7 @@ class BraviaTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
if user_input is not None: if user_input is not None:
host = user_input[CONF_HOST] host = user_input[CONF_HOST]
if host_valid(host): if is_host_valid(host):
session = async_create_clientsession( session = async_create_clientsession(
self.hass, self.hass,
cookie_jar=CookieJar(unsafe=True, quote_cookie=False), cookie_jar=CookieJar(unsafe=True, quote_cookie=False),

View File

@ -1,8 +1,6 @@
"""Adds config flow for Brother Printer.""" """Adds config flow for Brother Printer."""
from __future__ import annotations from __future__ import annotations
import ipaddress
import re
from typing import Any from typing import Any
from brother import Brother, SnmpError, UnsupportedModel from brother import Brother, SnmpError, UnsupportedModel
@ -12,6 +10,7 @@ from homeassistant import config_entries, exceptions
from homeassistant.components import zeroconf from homeassistant.components import zeroconf
from homeassistant.const import CONF_HOST, CONF_TYPE from homeassistant.const import CONF_HOST, CONF_TYPE
from homeassistant.data_entry_flow import FlowResult from homeassistant.data_entry_flow import FlowResult
from homeassistant.util.network import is_host_valid
from .const import DOMAIN, PRINTER_TYPES from .const import DOMAIN, PRINTER_TYPES
from .utils import get_snmp_engine from .utils import get_snmp_engine
@ -24,17 +23,6 @@ DATA_SCHEMA = vol.Schema(
) )
def host_valid(host: str) -> bool:
"""Return True if hostname or IP address is valid."""
try:
if ipaddress.ip_address(host).version in [4, 6]:
return True
except ValueError:
pass
disallowed = re.compile(r"[^a-zA-Z\d\-]")
return all(x and not disallowed.search(x) for x in host.split("."))
class BrotherConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): class BrotherConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
"""Handle a config flow for Brother Printer.""" """Handle a config flow for Brother Printer."""
@ -53,7 +41,7 @@ class BrotherConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
if user_input is not None: if user_input is not None:
try: try:
if not host_valid(user_input[CONF_HOST]): if not is_host_valid(user_input[CONF_HOST]):
raise InvalidHost() raise InvalidHost()
snmp_engine = get_snmp_engine(self.hass) snmp_engine = get_snmp_engine(self.hass)

View File

@ -1,8 +1,6 @@
"""Adds config flow for Dune HD integration.""" """Adds config flow for Dune HD integration."""
from __future__ import annotations from __future__ import annotations
import ipaddress
import re
from typing import Any from typing import Any
from pdunehd import DuneHDPlayer from pdunehd import DuneHDPlayer
@ -11,23 +9,11 @@ import voluptuous as vol
from homeassistant import config_entries, exceptions from homeassistant import config_entries, exceptions
from homeassistant.const import CONF_HOST from homeassistant.const import CONF_HOST
from homeassistant.data_entry_flow import FlowResult from homeassistant.data_entry_flow import FlowResult
from homeassistant.util.network import is_host_valid
from .const import DOMAIN from .const import DOMAIN
def host_valid(host: str) -> bool:
"""Return True if hostname or IP address is valid."""
try:
if ipaddress.ip_address(host).version in (4, 6):
return True
except ValueError:
pass
if len(host) > 253:
return False
allowed = re.compile(r"(?!-)[A-Z\d\-\_]{1,63}(?<!-)$", re.IGNORECASE)
return all(allowed.match(x) for x in host.split("."))
class DuneHDConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): class DuneHDConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
"""Handle a config flow for Dune HD integration.""" """Handle a config flow for Dune HD integration."""
@ -47,7 +33,7 @@ class DuneHDConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
errors = {} errors = {}
if user_input is not None: if user_input is not None:
if host_valid(user_input[CONF_HOST]): if is_host_valid(user_input[CONF_HOST]):
host: str = user_input[CONF_HOST] host: str = user_input[CONF_HOST]
try: try:

View File

@ -1,7 +1,5 @@
"""Config flow for Vilfo Router integration.""" """Config flow for Vilfo Router integration."""
import ipaddress
import logging import logging
import re
from vilfo import Client as VilfoClient from vilfo import Client as VilfoClient
from vilfo.exceptions import ( from vilfo.exceptions import (
@ -12,6 +10,7 @@ import voluptuous as vol
from homeassistant import config_entries, core, exceptions from homeassistant import config_entries, core, exceptions
from homeassistant.const import CONF_ACCESS_TOKEN, CONF_HOST, CONF_ID, CONF_MAC from homeassistant.const import CONF_ACCESS_TOKEN, CONF_HOST, CONF_ID, CONF_MAC
from homeassistant.util.network import is_host_valid
from .const import DOMAIN, ROUTER_DEFAULT_HOST from .const import DOMAIN, ROUTER_DEFAULT_HOST
@ -29,16 +28,6 @@ RESULT_CANNOT_CONNECT = "cannot_connect"
RESULT_INVALID_AUTH = "invalid_auth" RESULT_INVALID_AUTH = "invalid_auth"
def host_valid(host):
"""Return True if hostname or IP address is valid."""
try:
if ipaddress.ip_address(host).version in (4, 6):
return True
except ValueError:
disallowed = re.compile(r"[^a-zA-Z\d\-]")
return all(x and not disallowed.search(x) for x in host.split("."))
def _try_connect_and_fetch_basic_info(host, token): def _try_connect_and_fetch_basic_info(host, token):
"""Attempt to connect and call the ping endpoint and, if successful, fetch basic information.""" """Attempt to connect and call the ping endpoint and, if successful, fetch basic information."""
@ -80,7 +69,7 @@ async def validate_input(hass: core.HomeAssistant, data):
""" """
# Validate the host before doing anything else. # Validate the host before doing anything else.
if not host_valid(data[CONF_HOST]): if not is_host_valid(data[CONF_HOST]):
raise InvalidHost raise InvalidHost
config = {} config = {}

View File

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
from ipaddress import IPv4Address, IPv6Address, ip_address, ip_network from ipaddress import IPv4Address, IPv6Address, ip_address, ip_network
import re
import yarl import yarl
@ -86,6 +87,20 @@ def is_ipv6_address(address: str) -> bool:
return True return True
def is_host_valid(host: str) -> bool:
"""Check if a given string is an IP address or valid hostname."""
if is_ip_address(host):
return True
if len(host) > 255:
return False
if re.match(r"^[0-9\.]+$", host): # reject invalid IPv4
return False
if host.endswith("."): # dot at the end is correct
host = host[:-1]
allowed = re.compile(r"(?!-)[A-Z\d\-]{1,63}(?<!-)$", re.IGNORECASE)
return all(allowed.match(x) for x in host.split("."))
def normalize_url(address: str) -> str: def normalize_url(address: str) -> str:
"""Normalize a given URL.""" """Normalize a given URL."""
url = yarl.URL(address.rstrip("/")) url = yarl.URL(address.rstrip("/"))

View File

@ -80,6 +80,30 @@ def test_is_ipv6_address():
assert network_util.is_ipv6_address("8.8.8.8") is False assert network_util.is_ipv6_address("8.8.8.8") is False
def test_is_valid_host():
"""Test if strings are IPv6 addresses."""
assert network_util.is_host_valid("::1")
assert network_util.is_host_valid("::ffff:127.0.0.0")
assert network_util.is_host_valid("2001:0db8:85a3:0000:0000:8a2e:0370:7334")
assert network_util.is_host_valid("8.8.8.8")
assert network_util.is_host_valid("local")
assert network_util.is_host_valid("host-host")
assert network_util.is_host_valid("example.com")
assert network_util.is_host_valid("example.com.")
assert network_util.is_host_valid("Example123.com")
assert not network_util.is_host_valid("")
assert not network_util.is_host_valid("192.168.0.1:8080")
assert not network_util.is_host_valid("192.168.0.999")
assert not network_util.is_host_valid("2001:hb8::1:0:0:1")
assert not network_util.is_host_valid("-host-host")
assert not network_util.is_host_valid("host-host-")
assert not network_util.is_host_valid("host_host")
assert not network_util.is_host_valid("example.com/path")
assert not network_util.is_host_valid("example.com:8080")
assert not network_util.is_host_valid("verylonghostname" * 4)
assert not network_util.is_host_valid("verydeepdomain." * 18)
def test_normalize_url(): def test_normalize_url():
"""Test the normalizing of URLs.""" """Test the normalizing of URLs."""
assert network_util.normalize_url("http://example.com") == "http://example.com" assert network_util.normalize_url("http://example.com") == "http://example.com"