Add new network apis to reduce code duplication (#54832)

This commit is contained in:
J. Nick Koston 2021-08-18 12:33:26 -05:00 committed by GitHub
parent 30564d59b6
commit 6d0ce814e7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 48 additions and 51 deletions

View File

@ -1,13 +1,14 @@
"""The Network Configuration integration.""" """The Network Configuration integration."""
from __future__ import annotations from __future__ import annotations
from ipaddress import IPv4Address, IPv6Address
import logging import logging
import voluptuous as vol import voluptuous as vol
from homeassistant.components import websocket_api from homeassistant.components import websocket_api
from homeassistant.components.websocket_api.connection import ActiveConnection from homeassistant.components.websocket_api.connection import ActiveConnection
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
@ -45,6 +46,35 @@ async def async_get_source_ip(hass: HomeAssistant, target_ip: str) -> str:
return source_ip if source_ip in all_ipv4s else all_ipv4s[0] return source_ip if source_ip in all_ipv4s else all_ipv4s[0]
@bind_hass
async def async_get_enabled_source_ips(
hass: HomeAssistant,
) -> list[IPv4Address | IPv6Address]:
"""Build the list of enabled source ips."""
adapters = await async_get_adapters(hass)
sources: list[IPv4Address | IPv6Address] = []
for adapter in adapters:
if not adapter["enabled"]:
continue
if adapter["ipv4"]:
sources.extend(IPv4Address(ipv4["address"]) for ipv4 in adapter["ipv4"])
if adapter["ipv6"]:
# With python 3.9 add scope_ids can be
# added by enumerating adapter["ipv6"]s
# IPv6Address(f"::%{ipv6['scope_id']}")
sources.extend(IPv6Address(ipv6["address"]) for ipv6 in adapter["ipv6"])
return sources
@callback
def async_only_default_interface_enabled(adapters: list[Adapter]) -> bool:
"""Check to see if any non-default adapter is enabled."""
return not any(
adapter["enabled"] and not adapter["default"] for adapter in adapters
)
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up network for Home Assistant.""" """Set up network for Home Assistant."""

View File

@ -116,14 +116,6 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return True return True
@core_callback
def _async_use_default_interface(adapters: list[network.Adapter]) -> bool:
for adapter in adapters:
if adapter["enabled"] and not adapter["default"]:
return False
return True
@core_callback @core_callback
def _async_process_callbacks( def _async_process_callbacks(
callbacks: list[Callable[[dict], None]], discovery_info: dict[str, str] callbacks: list[Callable[[dict], None]], discovery_info: dict[str, str]
@ -204,24 +196,16 @@ class Scanner:
"""Build the list of ssdp sources.""" """Build the list of ssdp sources."""
adapters = await network.async_get_adapters(self.hass) adapters = await network.async_get_adapters(self.hass)
sources: set[IPv4Address | IPv6Address] = set() sources: set[IPv4Address | IPv6Address] = set()
if _async_use_default_interface(adapters): if network.async_only_default_interface_enabled(adapters):
sources.add(IPv4Address("0.0.0.0")) sources.add(IPv4Address("0.0.0.0"))
return sources return sources
for adapter in adapters: return {
if not adapter["enabled"]: source_ip
continue for source_ip in await network.async_get_enabled_source_ips(self.hass)
if adapter["ipv4"]: if not source_ip.is_loopback
ipv4 = adapter["ipv4"][0] and not (isinstance(source_ip, IPv6Address) and source_ip.is_global)
sources.add(IPv4Address(ipv4["address"])) }
if adapter["ipv6"]:
ipv6 = adapter["ipv6"][0]
# With python 3.9 add scope_ids can be
# added by enumerating adapter["ipv6"]s
# IPv6Address(f"::%{ipv6['scope_id']}")
sources.add(IPv6Address(ipv6["address"]))
return sources
async def async_scan(self, *_: Any) -> None: async def async_scan(self, *_: Any) -> None:
"""Scan for new entries using ssdp default and broadcast target.""" """Scan for new entries using ssdp default and broadcast target."""

View File

@ -5,7 +5,7 @@ import asyncio
from collections.abc import Coroutine from collections.abc import Coroutine
from contextlib import suppress from contextlib import suppress
import fnmatch import fnmatch
import ipaddress from ipaddress import IPv6Address, ip_address
import logging import logging
import socket import socket
from typing import Any, TypedDict, cast from typing import Any, TypedDict, cast
@ -131,13 +131,6 @@ async def _async_get_instance(hass: HomeAssistant, **zcargs: Any) -> HaAsyncZero
return aio_zc return aio_zc
def _async_use_default_interface(adapters: list[Adapter]) -> bool:
for adapter in adapters:
if adapter["enabled"] and not adapter["default"]:
return False
return True
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up Zeroconf and make Home Assistant discoverable.""" """Set up Zeroconf and make Home Assistant discoverable."""
zc_args: dict = {} zc_args: dict = {}
@ -151,25 +144,15 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
else: else:
zc_args["ip_version"] = IPVersion.All zc_args["ip_version"] = IPVersion.All
if not ipv6 and _async_use_default_interface(adapters): if not ipv6 and network.async_only_default_interface_enabled(adapters):
zc_args["interfaces"] = InterfaceChoice.Default zc_args["interfaces"] = InterfaceChoice.Default
else: else:
interfaces = zc_args["interfaces"] = [] zc_args["interfaces"] = [
for adapter in adapters: str(source_ip)
if not adapter["enabled"]: for source_ip in await network.async_get_enabled_source_ips(hass)
continue if not source_ip.is_loopback
if ipv4s := adapter["ipv4"]: and not (isinstance(source_ip, IPv6Address) and source_ip.is_global)
interfaces.extend( ]
ipv4["address"]
for ipv4 in ipv4s
if not ipaddress.IPv4Address(ipv4["address"]).is_loopback
)
if ipv6s := adapter["ipv6"]:
for ipv6_addr in ipv6s:
address = ipv6_addr["address"]
v6_ip_address = ipaddress.IPv6Address(address)
if not v6_ip_address.is_global and not v6_ip_address.is_loopback:
interfaces.append(ipv6_addr["address"])
aio_zc = await _async_get_instance(hass, **zc_args) aio_zc = await _async_get_instance(hass, **zc_args)
zeroconf = cast(HaZeroconf, aio_zc.zeroconf) zeroconf = cast(HaZeroconf, aio_zc.zeroconf)
@ -213,7 +196,7 @@ def _get_announced_addresses(
addresses = { addresses = {
addr.packed addr.packed
for addr in [ for addr in [
ipaddress.ip_address(ip["address"]) ip_address(ip["address"])
for adapter in adapters for adapter in adapters
if adapter["enabled"] if adapter["enabled"]
for ip in cast(list, adapter["ipv6"]) + cast(list, adapter["ipv4"]) for ip in cast(list, adapter["ipv6"]) + cast(list, adapter["ipv4"])
@ -530,7 +513,7 @@ def info_from_service(service: AsyncServiceInfo) -> HaServiceInfo | None:
address = service.addresses[0] address = service.addresses[0]
return { return {
"host": str(ipaddress.ip_address(address)), "host": str(ip_address(address)),
"port": service.port, "port": service.port,
"hostname": service.server, "hostname": service.server,
"type": service.type, "type": service.type,