diff --git a/supervisor/addons/validate.py b/supervisor/addons/validate.py index 246b18b46..68b232601 100644 --- a/supervisor/addons/validate.py +++ b/supervisor/addons/validate.py @@ -98,6 +98,7 @@ from ..validate import ( network_port, token, uuid_match, + version_tag, ) _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -180,7 +181,7 @@ def _simple_startup(value) -> str: SCHEMA_ADDON_CONFIG = vol.Schema( { vol.Required(ATTR_NAME): vol.Coerce(str), - vol.Required(ATTR_VERSION): vol.Coerce(str), + vol.Required(ATTR_VERSION): vol.All(version_tag, str), vol.Required(ATTR_SLUG): vol.Coerce(str), vol.Required(ATTR_DESCRIPTON): vol.Coerce(str), vol.Required(ATTR_ARCH): [vol.In(ARCH_ALL)], diff --git a/supervisor/api/audio.py b/supervisor/api/audio.py index 535719764..3a1e46677 100644 --- a/supervisor/api/audio.py +++ b/supervisor/api/audio.py @@ -33,12 +33,12 @@ from ..const import ( from ..coresys import CoreSysAttributes from ..exceptions import APIError from ..host.sound import StreamType -from ..validate import simple_version +from ..validate import version_tag from .utils import api_process, api_process_raw, api_validate _LOGGER: logging.Logger = logging.getLogger(__name__) -SCHEMA_VERSION = vol.Schema({vol.Optional(ATTR_VERSION): simple_version}) +SCHEMA_VERSION = vol.Schema({vol.Optional(ATTR_VERSION): version_tag}) SCHEMA_VOLUME = vol.Schema( { diff --git a/supervisor/api/cli.py b/supervisor/api/cli.py index 6bcb855ce..222defd4e 100644 --- a/supervisor/api/cli.py +++ b/supervisor/api/cli.py @@ -19,12 +19,12 @@ from ..const import ( ATTR_VERSION_LATEST, ) from ..coresys import CoreSysAttributes -from ..validate import simple_version +from ..validate import version_tag from .utils import api_process, api_validate _LOGGER: logging.Logger = logging.getLogger(__name__) -SCHEMA_VERSION = vol.Schema({vol.Optional(ATTR_VERSION): simple_version}) +SCHEMA_VERSION = vol.Schema({vol.Optional(ATTR_VERSION): version_tag}) class APICli(CoreSysAttributes): diff --git a/supervisor/api/dns.py b/supervisor/api/dns.py index 21a43787e..cde71058b 100644 --- a/supervisor/api/dns.py +++ b/supervisor/api/dns.py @@ -24,7 +24,7 @@ from ..const import ( ) from ..coresys import CoreSysAttributes from ..exceptions import APIError -from ..validate import dns_server_list, simple_version +from ..validate import dns_server_list, version_tag from .utils import api_process, api_process_raw, api_validate _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -32,7 +32,7 @@ _LOGGER: logging.Logger = logging.getLogger(__name__) # pylint: disable=no-value-for-parameter SCHEMA_OPTIONS = vol.Schema({vol.Optional(ATTR_SERVERS): dns_server_list}) -SCHEMA_VERSION = vol.Schema({vol.Optional(ATTR_VERSION): simple_version}) +SCHEMA_VERSION = vol.Schema({vol.Optional(ATTR_VERSION): version_tag}) class APICoreDNS(CoreSysAttributes): diff --git a/supervisor/api/homeassistant.py b/supervisor/api/homeassistant.py index 52338fd4c..a300f4c9f 100644 --- a/supervisor/api/homeassistant.py +++ b/supervisor/api/homeassistant.py @@ -33,7 +33,7 @@ from ..const import ( ) from ..coresys import CoreSysAttributes from ..exceptions import APIError -from ..validate import complex_version, docker_image, network_port +from ..validate import docker_image, network_port, version_tag from .utils import api_process, api_process_raw, api_validate _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -53,7 +53,7 @@ SCHEMA_OPTIONS = vol.Schema( } ) -SCHEMA_VERSION = vol.Schema({vol.Optional(ATTR_VERSION): complex_version}) +SCHEMA_VERSION = vol.Schema({vol.Optional(ATTR_VERSION): version_tag}) class APIHomeAssistant(CoreSysAttributes): diff --git a/supervisor/api/multicast.py b/supervisor/api/multicast.py index 59f1607fd..8c52da036 100644 --- a/supervisor/api/multicast.py +++ b/supervisor/api/multicast.py @@ -21,12 +21,12 @@ from ..const import ( ) from ..coresys import CoreSysAttributes from ..exceptions import APIError -from ..validate import simple_version +from ..validate import version_tag from .utils import api_process, api_process_raw, api_validate _LOGGER: logging.Logger = logging.getLogger(__name__) -SCHEMA_VERSION = vol.Schema({vol.Optional(ATTR_VERSION): simple_version}) +SCHEMA_VERSION = vol.Schema({vol.Optional(ATTR_VERSION): version_tag}) class APIMulticast(CoreSysAttributes): diff --git a/supervisor/api/os.py b/supervisor/api/os.py index dd90d4d72..63acf09ba 100644 --- a/supervisor/api/os.py +++ b/supervisor/api/os.py @@ -8,12 +8,12 @@ import voluptuous as vol from ..const import ATTR_BOARD, ATTR_BOOT, ATTR_VERSION, ATTR_VERSION_LATEST from ..coresys import CoreSysAttributes -from ..validate import complex_version +from ..validate import version_tag from .utils import api_process, api_validate _LOGGER: logging.Logger = logging.getLogger(__name__) -SCHEMA_VERSION = vol.Schema({vol.Optional(ATTR_VERSION): complex_version}) +SCHEMA_VERSION = vol.Schema({vol.Optional(ATTR_VERSION): version_tag}) class APIOS(CoreSysAttributes): diff --git a/supervisor/api/supervisor.py b/supervisor/api/supervisor.py index b61856586..864cf997e 100644 --- a/supervisor/api/supervisor.py +++ b/supervisor/api/supervisor.py @@ -43,7 +43,7 @@ from ..const import ( from ..coresys import CoreSysAttributes from ..exceptions import APIError from ..utils.validate import validate_timezone -from ..validate import repositories, simple_version, wait_boot +from ..validate import repositories, version_tag, wait_boot from .utils import api_process, api_process_raw, api_validate _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -61,7 +61,7 @@ SCHEMA_OPTIONS = vol.Schema( } ) -SCHEMA_VERSION = vol.Schema({vol.Optional(ATTR_VERSION): simple_version}) +SCHEMA_VERSION = vol.Schema({vol.Optional(ATTR_VERSION): version_tag}) class APISupervisor(CoreSysAttributes): diff --git a/supervisor/plugins/validate.py b/supervisor/plugins/validate.py index b9a93dee5..451050d93 100644 --- a/supervisor/plugins/validate.py +++ b/supervisor/plugins/validate.py @@ -3,11 +3,11 @@ import voluptuous as vol from ..const import ATTR_ACCESS_TOKEN, ATTR_IMAGE, ATTR_SERVERS, ATTR_VERSION -from ..validate import dns_server_list, docker_image, simple_version, token +from ..validate import dns_server_list, docker_image, token, version_tag SCHEMA_DNS_CONFIG = vol.Schema( { - vol.Optional(ATTR_VERSION): simple_version, + vol.Optional(ATTR_VERSION): version_tag, vol.Optional(ATTR_IMAGE): docker_image, vol.Optional(ATTR_SERVERS, default=list): dns_server_list, }, @@ -16,17 +16,14 @@ SCHEMA_DNS_CONFIG = vol.Schema( SCHEMA_AUDIO_CONFIG = vol.Schema( - { - vol.Optional(ATTR_VERSION): simple_version, - vol.Optional(ATTR_IMAGE): docker_image, - }, + {vol.Optional(ATTR_VERSION): version_tag, vol.Optional(ATTR_IMAGE): docker_image}, extra=vol.REMOVE_EXTRA, ) SCHEMA_CLI_CONFIG = vol.Schema( { - vol.Optional(ATTR_VERSION): simple_version, + vol.Optional(ATTR_VERSION): version_tag, vol.Optional(ATTR_IMAGE): docker_image, vol.Optional(ATTR_ACCESS_TOKEN): token, }, @@ -35,9 +32,6 @@ SCHEMA_CLI_CONFIG = vol.Schema( SCHEMA_MULTICAST_CONFIG = vol.Schema( - { - vol.Optional(ATTR_VERSION): simple_version, - vol.Optional(ATTR_IMAGE): docker_image, - }, + {vol.Optional(ATTR_VERSION): version_tag, vol.Optional(ATTR_IMAGE): docker_image}, extra=vol.REMOVE_EXTRA, ) diff --git a/supervisor/snapshots/validate.py b/supervisor/snapshots/validate.py index c1898f352..b186ed9d4 100644 --- a/supervisor/snapshots/validate.py +++ b/supervisor/snapshots/validate.py @@ -31,7 +31,7 @@ from ..const import ( SNAPSHOT_FULL, SNAPSHOT_PARTIAL, ) -from ..validate import complex_version, docker_image, network_port, repositories +from ..validate import docker_image, network_port, repositories, version_tag ALL_FOLDERS = [FOLDER_HOMEASSISTANT, FOLDER_SHARE, FOLDER_ADDONS, FOLDER_SSL] @@ -58,7 +58,7 @@ SCHEMA_SNAPSHOT = vol.Schema( vol.Inclusive(ATTR_CRYPTO, "encrypted"): CRYPTO_AES128, vol.Optional(ATTR_HOMEASSISTANT, default=dict): vol.Schema( { - vol.Optional(ATTR_VERSION): complex_version, + vol.Optional(ATTR_VERSION): version_tag, vol.Optional(ATTR_IMAGE): docker_image, vol.Optional(ATTR_BOOT, default=True): vol.Boolean(), vol.Optional(ATTR_SSL, default=False): vol.Boolean(), diff --git a/supervisor/validate.py b/supervisor/validate.py index fd46755bd..b5ca1250a 100644 --- a/supervisor/validate.py +++ b/supervisor/validate.py @@ -54,25 +54,15 @@ sha256 = vol.Match(r"^[0-9a-f]{64}$") token = vol.Match(r"^[0-9a-f]{32,256}$") -def simple_version(value: Union[str, int, None]) -> Optional[str]: +def version_tag(value: Union[str, None, int, float]) -> Optional[str]: """Validate main version handling.""" - if not isinstance(value, (str, int)): - return None - elif isinstance(value, int): - return str(value) - elif value.isnumeric() or value == "dev": - return value - return None - - -def complex_version(value: Union[str, None]) -> Optional[str]: - """Validate main version handling.""" - if not isinstance(value, str): + if value is None: return None try: + value = str(value) pkg_version.parse(value) - except pkg_version.InvalidVersion: + except (pkg_version.InvalidVersion, TypeError): raise vol.Invalid(f"Invalid version format {value}") return value @@ -126,7 +116,7 @@ DOCKER_PORTS_DESCRIPTION = vol.Schema( SCHEMA_HASS_CONFIG = vol.Schema( { vol.Optional(ATTR_UUID, default=lambda: uuid.uuid4().hex): uuid_match, - vol.Optional(ATTR_VERSION): complex_version, + vol.Optional(ATTR_VERSION): version_tag, vol.Optional(ATTR_IMAGE): docker_image, vol.Optional(ATTR_ACCESS_TOKEN): token, vol.Optional(ATTR_BOOT, default=True): vol.Boolean(), @@ -149,13 +139,13 @@ SCHEMA_UPDATER_CONFIG = vol.Schema( vol.Optional(ATTR_CHANNEL, default=UpdateChannels.STABLE): vol.Coerce( UpdateChannels ), - vol.Optional(ATTR_HOMEASSISTANT): complex_version, - vol.Optional(ATTR_SUPERVISOR): simple_version, - vol.Optional(ATTR_HASSOS): complex_version, - vol.Optional(ATTR_CLI): simple_version, - vol.Optional(ATTR_DNS): simple_version, - vol.Optional(ATTR_AUDIO): simple_version, - vol.Optional(ATTR_MULTICAST): simple_version, + vol.Optional(ATTR_HOMEASSISTANT): vol.All(version_tag, str), + vol.Optional(ATTR_SUPERVISOR): vol.All(version_tag, str), + vol.Optional(ATTR_HASSOS): vol.All(version_tag, str), + vol.Optional(ATTR_CLI): vol.All(version_tag, str), + vol.Optional(ATTR_DNS): vol.All(version_tag, str), + vol.Optional(ATTR_AUDIO): vol.All(version_tag, str), + vol.Optional(ATTR_MULTICAST): vol.All(version_tag, str), vol.Optional(ATTR_IMAGE, default=dict): vol.Schema( { vol.Optional(ATTR_HOMEASSISTANT): docker_image, @@ -177,7 +167,7 @@ SCHEMA_SUPERVISOR_CONFIG = vol.Schema( { vol.Optional(ATTR_TIMEZONE, default="UTC"): validate_timezone, vol.Optional(ATTR_LAST_BOOT): vol.Coerce(str), - vol.Optional(ATTR_VERSION, default=SUPERVISOR_VERSION): simple_version, + vol.Optional(ATTR_VERSION, default=SUPERVISOR_VERSION): version_tag, vol.Optional( ATTR_ADDONS_CUSTOM_LIST, default=["https://github.com/hassio-addons/repository"], diff --git a/tests/test_validate.py b/tests/test_validate.py index 641786ba4..b09ae2b63 100644 --- a/tests/test_validate.py +++ b/tests/test_validate.py @@ -1,68 +1,88 @@ """Test validators.""" import pytest -import voluptuous.error +import voluptuous as vol -import supervisor.validate +from supervisor import validate -GOOD_V4 = [ +DNS_GOOD_V4 = [ "dns://10.0.0.1", # random local "dns://254.254.254.254", # random high numbers "DNS://1.1.1.1", # cloudflare "dns://9.9.9.9", # quad-9 ] -GOOD_V6 = [ +DNS_GOOD_V6 = [ "dns://2606:4700:4700::1111", # cloudflare "DNS://2606:4700:4700::1001", # cloudflare ] -BAD = ["hello world", "https://foo.bar", "", "dns://example.com"] +DNS_BAD = ["hello world", "https://foo.bar", "", "dns://example.com"] async def test_dns_url_v4_good(): """Test the DNS validator with known-good ipv6 DNS URLs.""" - for url in GOOD_V4: - assert supervisor.validate.dns_url(url) + for url in DNS_GOOD_V4: + assert validate.dns_url(url) def test_dns_url_v6_good(): """Test the DNS validator with known-good ipv6 DNS URLs.""" - for url in GOOD_V6: - assert supervisor.validate.dns_url(url) + for url in DNS_GOOD_V6: + assert validate.dns_url(url) def test_dns_server_list_v4(): """Test a list with v4 addresses.""" - assert supervisor.validate.dns_server_list(GOOD_V4) + assert validate.dns_server_list(DNS_GOOD_V4) def test_dns_server_list_v6(): """Test a list with v6 addresses.""" - assert supervisor.validate.dns_server_list(GOOD_V6) + assert validate.dns_server_list(DNS_GOOD_V6) def test_dns_server_list_combined(): """Test a list with both v4 and v6 addresses.""" - combined = GOOD_V4 + GOOD_V6 + combined = DNS_GOOD_V4 + DNS_GOOD_V6 # test the matches - assert supervisor.validate.dns_server_list(combined) + assert validate.dns_server_list(combined) # test max_length is OK still - assert supervisor.validate.dns_server_list(combined) + assert validate.dns_server_list(combined) # test that it fails when the list is too long - with pytest.raises(voluptuous.error.Invalid): - supervisor.validate.dns_server_list(combined + combined + combined + combined) + with pytest.raises(vol.error.Invalid): + validate.dns_server_list(combined + combined + combined + combined) def test_dns_server_list_bad(): """Test the bad list.""" # test the matches - with pytest.raises(voluptuous.error.Invalid): - assert supervisor.validate.dns_server_list(BAD) + with pytest.raises(vol.error.Invalid): + assert validate.dns_server_list(DNS_BAD) def test_dns_server_list_bad_combined(): """Test the bad list, combined with the good.""" - combined = GOOD_V4 + GOOD_V6 + BAD + combined = DNS_GOOD_V4 + DNS_GOOD_V6 + DNS_BAD - with pytest.raises(voluptuous.error.Invalid): + with pytest.raises(vol.error.Invalid): # bad list - assert supervisor.validate.dns_server_list(combined) + assert validate.dns_server_list(combined) + + +def test_version_complex(): + """Test version simple with good version.""" + for version in ( + "landingpage", + "1c002dd", + "1.1.1", + "1.0", + "0.150.1", + "0.150.1b1", + "0.150.1.dev20200715", + "1", + "alpine-5.4", + 1, + 1.1, + ): + assert validate.version_tag(version) == str(version) + + assert validate.version_tag(None) is None