diff --git a/esphome/dashboard/web_server.py b/esphome/dashboard/web_server.py index 9c20cf4f58..6196e01760 100644 --- a/esphome/dashboard/web_server.py +++ b/esphome/dashboard/web_server.py @@ -38,7 +38,7 @@ import yaml from yaml.nodes import Node from esphome import const, platformio_api, yaml_util -from esphome.helpers import get_bool_env, mkdir_p +from esphome.helpers import get_bool_env, mkdir_p, sort_ip_addresses from esphome.storage_json import ( StorageJSON, archive_storage_path, @@ -336,7 +336,7 @@ class EsphomePortCommandWebSocket(EsphomeCommandWebSocket): # Use the IP address if available but only # if the API is loaded and the device is online # since MQTT logging will not work otherwise - port = address_list[0] + port = sort_ip_addresses(address_list)[0] elif ( entry.address and ( @@ -347,7 +347,7 @@ class EsphomePortCommandWebSocket(EsphomeCommandWebSocket): and not isinstance(address_list, Exception) ): # If mdns is not available, try to use the DNS cache - port = address_list[0] + port = sort_ip_addresses(address_list)[0] return [ *DASHBOARD_COMMAND, diff --git a/esphome/helpers.py b/esphome/helpers.py index 8aae43c2bb..b649465d69 100644 --- a/esphome/helpers.py +++ b/esphome/helpers.py @@ -200,6 +200,45 @@ def resolve_ip_address(host, port): return res +def sort_ip_addresses(address_list: list[str]) -> list[str]: + """Takes a list of IP addresses in string form, e.g. from mDNS or MQTT, + and sorts them into the best order to actually try connecting to them. + + This is roughly based on RFC6724 but a lot simpler: First we choose + IPv6 addresses, then Legacy IP addresses, and lowest priority is + link-local IPv6 addresses that don't have a link specified (which + are useless, but mDNS does provide them in that form). Addresses + which cannot be parsed are silently dropped. + """ + import socket + + # First "resolve" all the IP addresses to getaddrinfo() tuples of the form + # (family, type, proto, canonname, sockaddr) + res: list[ + tuple[ + int, + int, + int, + Union[str, None], + Union[tuple[str, int], tuple[str, int, int, int]], + ] + ] = [] + for addr in address_list: + # This should always work as these are supposed to be IP addresses + try: + res += socket.getaddrinfo( + addr, 0, proto=socket.IPPROTO_TCP, flags=socket.AI_NUMERICHOST + ) + except OSError: + _LOGGER.info("Failed to parse IP address '%s'", addr) + + # Now use that information to sort them. + res.sort(key=addr_preference_) + + # Finally, turn the getaddrinfo() tuples back into plain hostnames. + return [socket.getnameinfo(r[4], socket.NI_NUMERICHOST)[0] for r in res] + + def get_bool_env(var, default=False): value = os.getenv(var, default) if isinstance(value, str): diff --git a/tests/unit_tests/test_helpers.py b/tests/unit_tests/test_helpers.py index 862320b09e..b353d1aa99 100644 --- a/tests/unit_tests/test_helpers.py +++ b/tests/unit_tests/test_helpers.py @@ -267,3 +267,13 @@ def test_sanitize(text, expected): actual = helpers.sanitize(text) assert actual == expected + + +@pytest.mark.parametrize( + "text, expected", + ((["127.0.0.1", "fe80::1", "2001::2"], ["2001::2", "127.0.0.1", "fe80::1"]),), +) +def test_sort_ip_addresses(text: list[str], expected: list[str]) -> None: + actual = helpers.sort_ip_addresses(text) + + assert actual == expected