mirror of
https://github.com/home-assistant/core.git
synced 2025-07-15 01:07:10 +00:00
Add new network apis to reduce code duplication (#54832)
This commit is contained in:
parent
30564d59b6
commit
6d0ce814e7
@ -1,13 +1,14 @@
|
|||||||
"""The Network Configuration integration."""
|
"""The Network Configuration integration."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from ipaddress import IPv4Address, IPv6Address
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.components import websocket_api
|
from homeassistant.components import websocket_api
|
||||||
from homeassistant.components.websocket_api.connection import ActiveConnection
|
from homeassistant.components.websocket_api.connection import ActiveConnection
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.helpers.typing import ConfigType
|
from homeassistant.helpers.typing import ConfigType
|
||||||
from homeassistant.loader import bind_hass
|
from homeassistant.loader import bind_hass
|
||||||
|
|
||||||
@ -45,6 +46,35 @@ async def async_get_source_ip(hass: HomeAssistant, target_ip: str) -> str:
|
|||||||
return source_ip if source_ip in all_ipv4s else all_ipv4s[0]
|
return source_ip if source_ip in all_ipv4s else all_ipv4s[0]
|
||||||
|
|
||||||
|
|
||||||
|
@bind_hass
|
||||||
|
async def async_get_enabled_source_ips(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
) -> list[IPv4Address | IPv6Address]:
|
||||||
|
"""Build the list of enabled source ips."""
|
||||||
|
adapters = await async_get_adapters(hass)
|
||||||
|
sources: list[IPv4Address | IPv6Address] = []
|
||||||
|
for adapter in adapters:
|
||||||
|
if not adapter["enabled"]:
|
||||||
|
continue
|
||||||
|
if adapter["ipv4"]:
|
||||||
|
sources.extend(IPv4Address(ipv4["address"]) for ipv4 in adapter["ipv4"])
|
||||||
|
if adapter["ipv6"]:
|
||||||
|
# With python 3.9 add scope_ids can be
|
||||||
|
# added by enumerating adapter["ipv6"]s
|
||||||
|
# IPv6Address(f"::%{ipv6['scope_id']}")
|
||||||
|
sources.extend(IPv6Address(ipv6["address"]) for ipv6 in adapter["ipv6"])
|
||||||
|
|
||||||
|
return sources
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_only_default_interface_enabled(adapters: list[Adapter]) -> bool:
|
||||||
|
"""Check to see if any non-default adapter is enabled."""
|
||||||
|
return not any(
|
||||||
|
adapter["enabled"] and not adapter["default"] for adapter in adapters
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||||
"""Set up network for Home Assistant."""
|
"""Set up network for Home Assistant."""
|
||||||
|
|
||||||
|
@ -116,14 +116,6 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
@core_callback
|
|
||||||
def _async_use_default_interface(adapters: list[network.Adapter]) -> bool:
|
|
||||||
for adapter in adapters:
|
|
||||||
if adapter["enabled"] and not adapter["default"]:
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
@core_callback
|
@core_callback
|
||||||
def _async_process_callbacks(
|
def _async_process_callbacks(
|
||||||
callbacks: list[Callable[[dict], None]], discovery_info: dict[str, str]
|
callbacks: list[Callable[[dict], None]], discovery_info: dict[str, str]
|
||||||
@ -204,24 +196,16 @@ class Scanner:
|
|||||||
"""Build the list of ssdp sources."""
|
"""Build the list of ssdp sources."""
|
||||||
adapters = await network.async_get_adapters(self.hass)
|
adapters = await network.async_get_adapters(self.hass)
|
||||||
sources: set[IPv4Address | IPv6Address] = set()
|
sources: set[IPv4Address | IPv6Address] = set()
|
||||||
if _async_use_default_interface(adapters):
|
if network.async_only_default_interface_enabled(adapters):
|
||||||
sources.add(IPv4Address("0.0.0.0"))
|
sources.add(IPv4Address("0.0.0.0"))
|
||||||
return sources
|
return sources
|
||||||
|
|
||||||
for adapter in adapters:
|
return {
|
||||||
if not adapter["enabled"]:
|
source_ip
|
||||||
continue
|
for source_ip in await network.async_get_enabled_source_ips(self.hass)
|
||||||
if adapter["ipv4"]:
|
if not source_ip.is_loopback
|
||||||
ipv4 = adapter["ipv4"][0]
|
and not (isinstance(source_ip, IPv6Address) and source_ip.is_global)
|
||||||
sources.add(IPv4Address(ipv4["address"]))
|
}
|
||||||
if adapter["ipv6"]:
|
|
||||||
ipv6 = adapter["ipv6"][0]
|
|
||||||
# With python 3.9 add scope_ids can be
|
|
||||||
# added by enumerating adapter["ipv6"]s
|
|
||||||
# IPv6Address(f"::%{ipv6['scope_id']}")
|
|
||||||
sources.add(IPv6Address(ipv6["address"]))
|
|
||||||
|
|
||||||
return sources
|
|
||||||
|
|
||||||
async def async_scan(self, *_: Any) -> None:
|
async def async_scan(self, *_: Any) -> None:
|
||||||
"""Scan for new entries using ssdp default and broadcast target."""
|
"""Scan for new entries using ssdp default and broadcast target."""
|
||||||
|
@ -5,7 +5,7 @@ import asyncio
|
|||||||
from collections.abc import Coroutine
|
from collections.abc import Coroutine
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
import fnmatch
|
import fnmatch
|
||||||
import ipaddress
|
from ipaddress import IPv6Address, ip_address
|
||||||
import logging
|
import logging
|
||||||
import socket
|
import socket
|
||||||
from typing import Any, TypedDict, cast
|
from typing import Any, TypedDict, cast
|
||||||
@ -131,13 +131,6 @@ async def _async_get_instance(hass: HomeAssistant, **zcargs: Any) -> HaAsyncZero
|
|||||||
return aio_zc
|
return aio_zc
|
||||||
|
|
||||||
|
|
||||||
def _async_use_default_interface(adapters: list[Adapter]) -> bool:
|
|
||||||
for adapter in adapters:
|
|
||||||
if adapter["enabled"] and not adapter["default"]:
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||||
"""Set up Zeroconf and make Home Assistant discoverable."""
|
"""Set up Zeroconf and make Home Assistant discoverable."""
|
||||||
zc_args: dict = {}
|
zc_args: dict = {}
|
||||||
@ -151,25 +144,15 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||||||
else:
|
else:
|
||||||
zc_args["ip_version"] = IPVersion.All
|
zc_args["ip_version"] = IPVersion.All
|
||||||
|
|
||||||
if not ipv6 and _async_use_default_interface(adapters):
|
if not ipv6 and network.async_only_default_interface_enabled(adapters):
|
||||||
zc_args["interfaces"] = InterfaceChoice.Default
|
zc_args["interfaces"] = InterfaceChoice.Default
|
||||||
else:
|
else:
|
||||||
interfaces = zc_args["interfaces"] = []
|
zc_args["interfaces"] = [
|
||||||
for adapter in adapters:
|
str(source_ip)
|
||||||
if not adapter["enabled"]:
|
for source_ip in await network.async_get_enabled_source_ips(hass)
|
||||||
continue
|
if not source_ip.is_loopback
|
||||||
if ipv4s := adapter["ipv4"]:
|
and not (isinstance(source_ip, IPv6Address) and source_ip.is_global)
|
||||||
interfaces.extend(
|
]
|
||||||
ipv4["address"]
|
|
||||||
for ipv4 in ipv4s
|
|
||||||
if not ipaddress.IPv4Address(ipv4["address"]).is_loopback
|
|
||||||
)
|
|
||||||
if ipv6s := adapter["ipv6"]:
|
|
||||||
for ipv6_addr in ipv6s:
|
|
||||||
address = ipv6_addr["address"]
|
|
||||||
v6_ip_address = ipaddress.IPv6Address(address)
|
|
||||||
if not v6_ip_address.is_global and not v6_ip_address.is_loopback:
|
|
||||||
interfaces.append(ipv6_addr["address"])
|
|
||||||
|
|
||||||
aio_zc = await _async_get_instance(hass, **zc_args)
|
aio_zc = await _async_get_instance(hass, **zc_args)
|
||||||
zeroconf = cast(HaZeroconf, aio_zc.zeroconf)
|
zeroconf = cast(HaZeroconf, aio_zc.zeroconf)
|
||||||
@ -213,7 +196,7 @@ def _get_announced_addresses(
|
|||||||
addresses = {
|
addresses = {
|
||||||
addr.packed
|
addr.packed
|
||||||
for addr in [
|
for addr in [
|
||||||
ipaddress.ip_address(ip["address"])
|
ip_address(ip["address"])
|
||||||
for adapter in adapters
|
for adapter in adapters
|
||||||
if adapter["enabled"]
|
if adapter["enabled"]
|
||||||
for ip in cast(list, adapter["ipv6"]) + cast(list, adapter["ipv4"])
|
for ip in cast(list, adapter["ipv6"]) + cast(list, adapter["ipv4"])
|
||||||
@ -530,7 +513,7 @@ def info_from_service(service: AsyncServiceInfo) -> HaServiceInfo | None:
|
|||||||
address = service.addresses[0]
|
address = service.addresses[0]
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"host": str(ipaddress.ip_address(address)),
|
"host": str(ip_address(address)),
|
||||||
"port": service.port,
|
"port": service.port,
|
||||||
"hostname": service.server,
|
"hostname": service.server,
|
||||||
"type": service.type,
|
"type": service.type,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user