Strict typing for dhcp (#67361)

This commit is contained in:
J. Nick Koston 2022-02-28 18:49:44 -10:00 committed by GitHub
parent 21ce441a97
commit 076fe97110
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 133 additions and 73 deletions

View File

@ -68,6 +68,7 @@ homeassistant.components.device_automation.*
homeassistant.components.device_tracker.* homeassistant.components.device_tracker.*
homeassistant.components.devolo_home_control.* homeassistant.components.devolo_home_control.*
homeassistant.components.devolo_home_network.* homeassistant.components.devolo_home_network.*
homeassistant.components.dhcp.*
homeassistant.components.dlna_dmr.* homeassistant.components.dlna_dmr.*
homeassistant.components.dnsip.* homeassistant.components.dnsip.*
homeassistant.components.dsmr.* homeassistant.components.dsmr.*

View File

@ -2,6 +2,9 @@
from __future__ import annotations from __future__ import annotations
from abc import abstractmethod from abc import abstractmethod
import asyncio
from collections.abc import Callable, Iterable
import contextlib
from dataclasses import dataclass from dataclasses import dataclass
from datetime import timedelta from datetime import timedelta
import fnmatch import fnmatch
@ -9,7 +12,7 @@ from ipaddress import ip_address as make_ip_address
import logging import logging
import os import os
import threading import threading
from typing import Any, Final from typing import TYPE_CHECKING, Any, Final, cast
from aiodiscover import DiscoverHosts from aiodiscover import DiscoverHosts
from aiodiscover.discovery import ( from aiodiscover.discovery import (
@ -51,12 +54,16 @@ from homeassistant.helpers.event import (
) )
from homeassistant.helpers.frame import report from homeassistant.helpers.frame import report
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import async_get_dhcp from homeassistant.loader import DHCPMatcher, async_get_dhcp
from homeassistant.util.async_ import run_callback_threadsafe from homeassistant.util.async_ import run_callback_threadsafe
from homeassistant.util.network import is_invalid, is_link_local, is_loopback from homeassistant.util.network import is_invalid, is_link_local, is_loopback
from .const import DOMAIN from .const import DOMAIN
if TYPE_CHECKING:
from scapy.packet import Packet
from scapy.sendrecv import AsyncSniffer
FILTER = "udp and (port 67 or 68)" FILTER = "udp and (port 67 or 68)"
REQUESTED_ADDR = "requested_addr" REQUESTED_ADDR = "requested_addr"
MESSAGE_TYPE = "message-type" MESSAGE_TYPE = "message-type"
@ -115,7 +122,6 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
watchers: list[WatcherBase] = [] watchers: list[WatcherBase] = []
address_data: dict[str, dict[str, str]] = {} address_data: dict[str, dict[str, str]] = {}
integration_matchers = await async_get_dhcp(hass) integration_matchers = await async_get_dhcp(hass)
# For the passive classes we need to start listening # For the passive classes we need to start listening
# for state changes and connect the dispatchers before # for state changes and connect the dispatchers before
# everything else starts up or we will miss events # everything else starts up or we will miss events
@ -124,13 +130,13 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
await passive_watcher.async_start() await passive_watcher.async_start()
watchers.append(passive_watcher) watchers.append(passive_watcher)
async def _initialize(_): async def _initialize(event: Event) -> None:
for active_cls in (DHCPWatcher, NetworkWatcher): for active_cls in (DHCPWatcher, NetworkWatcher):
active_watcher = active_cls(hass, address_data, integration_matchers) active_watcher = active_cls(hass, address_data, integration_matchers)
await active_watcher.async_start() await active_watcher.async_start()
watchers.append(active_watcher) watchers.append(active_watcher)
async def _async_stop(*_): async def _async_stop(event: Event) -> None:
for watcher in watchers: for watcher in watchers:
await watcher.async_stop() await watcher.async_stop()
@ -143,7 +149,12 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
class WatcherBase: class WatcherBase:
"""Base class for dhcp and device tracker watching.""" """Base class for dhcp and device tracker watching."""
def __init__(self, hass, address_data, integration_matchers): def __init__(
self,
hass: HomeAssistant,
address_data: dict[str, dict[str, str]],
integration_matchers: list[DHCPMatcher],
) -> None:
"""Initialize class.""" """Initialize class."""
super().__init__() super().__init__()
@ -152,11 +163,11 @@ class WatcherBase:
self._address_data = address_data self._address_data = address_data
@abstractmethod @abstractmethod
async def async_stop(self): async def async_stop(self) -> None:
"""Stop the watcher.""" """Stop the watcher."""
@abstractmethod @abstractmethod
async def async_start(self): async def async_start(self) -> None:
"""Start the watcher.""" """Start the watcher."""
def process_client(self, ip_address: str, hostname: str, mac_address: str) -> None: def process_client(self, ip_address: str, hostname: str, mac_address: str) -> None:
@ -197,8 +208,8 @@ class WatcherBase:
data = {MAC_ADDRESS: mac_address, HOSTNAME: hostname} data = {MAC_ADDRESS: mac_address, HOSTNAME: hostname}
self._address_data[ip_address] = data self._address_data[ip_address] = data
lowercase_hostname = data[HOSTNAME].lower() lowercase_hostname = hostname.lower()
uppercase_mac = data[MAC_ADDRESS].upper() uppercase_mac = mac_address.upper()
_LOGGER.debug( _LOGGER.debug(
"Processing updated address data for %s: mac=%s hostname=%s", "Processing updated address data for %s: mac=%s hostname=%s",
@ -218,22 +229,24 @@ class WatcherBase:
if entry := self.hass.config_entries.async_get_entry(entry_id): if entry := self.hass.config_entries.async_get_entry(entry_id):
device_domains.add(entry.domain) device_domains.add(entry.domain)
for entry in self._integration_matchers: for matcher in self._integration_matchers:
if entry.get(REGISTERED_DEVICES) and not entry["domain"] in device_domains: domain = matcher["domain"]
if matcher.get(REGISTERED_DEVICES) and domain not in device_domains:
continue continue
if MAC_ADDRESS in entry and not fnmatch.fnmatch( if (
uppercase_mac, entry[MAC_ADDRESS] matcher_mac := matcher.get(MAC_ADDRESS)
): ) is not None and not fnmatch.fnmatch(uppercase_mac, matcher_mac):
continue continue
if HOSTNAME in entry and not fnmatch.fnmatch( if (
lowercase_hostname, entry[HOSTNAME] matcher_hostname := matcher.get(HOSTNAME)
): ) is not None and not fnmatch.fnmatch(lowercase_hostname, matcher_hostname):
continue continue
_LOGGER.debug("Matched %s against %s", data, entry) _LOGGER.debug("Matched %s against %s", data, matcher)
matched_domains.add(entry["domain"]) matched_domains.add(domain)
for domain in matched_domains: for domain in matched_domains:
discovery_flow.async_create_flow( discovery_flow.async_create_flow(
@ -243,7 +256,7 @@ class WatcherBase:
DhcpServiceInfo( DhcpServiceInfo(
ip=ip_address, ip=ip_address,
hostname=lowercase_hostname, hostname=lowercase_hostname,
macaddress=data[MAC_ADDRESS], macaddress=mac_address,
), ),
) )
@ -251,14 +264,19 @@ class WatcherBase:
class NetworkWatcher(WatcherBase): class NetworkWatcher(WatcherBase):
"""Class to query ptr records routers.""" """Class to query ptr records routers."""
def __init__(self, hass, address_data, integration_matchers): def __init__(
self,
hass: HomeAssistant,
address_data: dict[str, dict[str, str]],
integration_matchers: list[DHCPMatcher],
) -> None:
"""Initialize class.""" """Initialize class."""
super().__init__(hass, address_data, integration_matchers) super().__init__(hass, address_data, integration_matchers)
self._unsub = None self._unsub: Callable[[], None] | None = None
self._discover_hosts = None self._discover_hosts: DiscoverHosts | None = None
self._discover_task = None self._discover_task: asyncio.Task | None = None
async def async_stop(self): async def async_stop(self) -> None:
"""Stop scanning for new devices on the network.""" """Stop scanning for new devices on the network."""
if self._unsub: if self._unsub:
self._unsub() self._unsub()
@ -267,7 +285,7 @@ class NetworkWatcher(WatcherBase):
self._discover_task.cancel() self._discover_task.cancel()
self._discover_task = None self._discover_task = None
async def async_start(self): async def async_start(self) -> None:
"""Start scanning for new devices on the network.""" """Start scanning for new devices on the network."""
self._discover_hosts = DiscoverHosts() self._discover_hosts = DiscoverHosts()
self._unsub = async_track_time_interval( self._unsub = async_track_time_interval(
@ -276,14 +294,15 @@ class NetworkWatcher(WatcherBase):
self.async_start_discover() self.async_start_discover()
@callback @callback
def async_start_discover(self, *_): def async_start_discover(self, *_: Any) -> None:
"""Start a new discovery task if one is not running.""" """Start a new discovery task if one is not running."""
if self._discover_task and not self._discover_task.done(): if self._discover_task and not self._discover_task.done():
return return
self._discover_task = self.hass.async_create_task(self.async_discover()) self._discover_task = self.hass.async_create_task(self.async_discover())
async def async_discover(self): async def async_discover(self) -> None:
"""Process discovery.""" """Process discovery."""
assert self._discover_hosts is not None
for host in await self._discover_hosts.async_discover(): for host in await self._discover_hosts.async_discover():
self.async_process_client( self.async_process_client(
host[DISCOVERY_IP_ADDRESS], host[DISCOVERY_IP_ADDRESS],
@ -295,18 +314,23 @@ class NetworkWatcher(WatcherBase):
class DeviceTrackerWatcher(WatcherBase): class DeviceTrackerWatcher(WatcherBase):
"""Class to watch dhcp data from routers.""" """Class to watch dhcp data from routers."""
def __init__(self, hass, address_data, integration_matchers): def __init__(
self,
hass: HomeAssistant,
address_data: dict[str, dict[str, str]],
integration_matchers: list[DHCPMatcher],
) -> None:
"""Initialize class.""" """Initialize class."""
super().__init__(hass, address_data, integration_matchers) super().__init__(hass, address_data, integration_matchers)
self._unsub = None self._unsub: Callable[[], None] | None = None
async def async_stop(self): async def async_stop(self) -> None:
"""Stop watching for new device trackers.""" """Stop watching for new device trackers."""
if self._unsub: if self._unsub:
self._unsub() self._unsub()
self._unsub = None self._unsub = None
async def async_start(self): async def async_start(self) -> None:
"""Stop watching for new device trackers.""" """Stop watching for new device trackers."""
self._unsub = async_track_state_added_domain( self._unsub = async_track_state_added_domain(
self.hass, [DEVICE_TRACKER_DOMAIN], self._async_process_device_event self.hass, [DEVICE_TRACKER_DOMAIN], self._async_process_device_event
@ -315,12 +339,12 @@ class DeviceTrackerWatcher(WatcherBase):
self._async_process_device_state(state) self._async_process_device_state(state)
@callback @callback
def _async_process_device_event(self, event: Event): def _async_process_device_event(self, event: Event) -> None:
"""Process a device tracker state change event.""" """Process a device tracker state change event."""
self._async_process_device_state(event.data["new_state"]) self._async_process_device_state(event.data["new_state"])
@callback @callback
def _async_process_device_state(self, state: State): def _async_process_device_state(self, state: State) -> None:
"""Process a device tracker state.""" """Process a device tracker state."""
if state.state != STATE_HOME: if state.state != STATE_HOME:
return return
@ -343,18 +367,23 @@ class DeviceTrackerWatcher(WatcherBase):
class DeviceTrackerRegisteredWatcher(WatcherBase): class DeviceTrackerRegisteredWatcher(WatcherBase):
"""Class to watch data from device tracker registrations.""" """Class to watch data from device tracker registrations."""
def __init__(self, hass, address_data, integration_matchers): def __init__(
self,
hass: HomeAssistant,
address_data: dict[str, dict[str, str]],
integration_matchers: list[DHCPMatcher],
) -> None:
"""Initialize class.""" """Initialize class."""
super().__init__(hass, address_data, integration_matchers) super().__init__(hass, address_data, integration_matchers)
self._unsub = None self._unsub: Callable[[], None] | None = None
async def async_stop(self): async def async_stop(self) -> None:
"""Stop watching for device tracker registrations.""" """Stop watching for device tracker registrations."""
if self._unsub: if self._unsub:
self._unsub() self._unsub()
self._unsub = None self._unsub = None
async def async_start(self): async def async_start(self) -> None:
"""Stop watching for device tracker registrations.""" """Stop watching for device tracker registrations."""
self._unsub = async_dispatcher_connect( self._unsub = async_dispatcher_connect(
self.hass, CONNECTED_DEVICE_REGISTERED, self._async_process_device_data self.hass, CONNECTED_DEVICE_REGISTERED, self._async_process_device_data
@ -376,26 +405,32 @@ class DeviceTrackerRegisteredWatcher(WatcherBase):
class DHCPWatcher(WatcherBase): class DHCPWatcher(WatcherBase):
"""Class to watch dhcp requests.""" """Class to watch dhcp requests."""
def __init__(self, hass, address_data, integration_matchers): def __init__(
self,
hass: HomeAssistant,
address_data: dict[str, dict[str, str]],
integration_matchers: list[DHCPMatcher],
) -> None:
"""Initialize class.""" """Initialize class."""
super().__init__(hass, address_data, integration_matchers) super().__init__(hass, address_data, integration_matchers)
self._sniffer = None self._sniffer: AsyncSniffer | None = None
self._started = threading.Event() self._started = threading.Event()
async def async_stop(self): async def async_stop(self) -> None:
"""Stop watching for new device trackers.""" """Stop watching for new device trackers."""
await self.hass.async_add_executor_job(self._stop) await self.hass.async_add_executor_job(self._stop)
def _stop(self): def _stop(self) -> None:
"""Stop the thread.""" """Stop the thread."""
if self._started.is_set(): if self._started.is_set():
assert self._sniffer is not None
self._sniffer.stop() self._sniffer.stop()
async def async_start(self): async def async_start(self) -> None:
"""Start watching for dhcp packets.""" """Start watching for dhcp packets."""
await self.hass.async_add_executor_job(self._start) await self.hass.async_add_executor_job(self._start)
def _start(self): def _start(self) -> None:
"""Start watching for dhcp packets.""" """Start watching for dhcp packets."""
# Local import because importing from scapy has side effects such as opening # Local import because importing from scapy has side effects such as opening
# sockets # sockets
@ -417,20 +452,25 @@ class DHCPWatcher(WatcherBase):
AsyncSniffer, AsyncSniffer,
) )
def _handle_dhcp_packet(packet): def _handle_dhcp_packet(packet: Packet) -> None:
"""Process a dhcp packet.""" """Process a dhcp packet."""
if DHCP not in packet: if DHCP not in packet:
return return
options = packet[DHCP].options options_dict = _dhcp_options_as_dict(packet[DHCP].options)
request_type = _decode_dhcp_option(options, MESSAGE_TYPE) if options_dict.get(MESSAGE_TYPE) != DHCP_REQUEST:
if request_type != DHCP_REQUEST:
# Not a DHCP request # Not a DHCP request
return return
ip_address = _decode_dhcp_option(options, REQUESTED_ADDR) or packet[IP].src ip_address = options_dict.get(REQUESTED_ADDR) or cast(str, packet[IP].src)
hostname = _decode_dhcp_option(options, HOSTNAME) or "" assert isinstance(ip_address, str)
mac_address = _format_mac(packet[Ether].src) hostname = ""
if (hostname_bytes := options_dict.get(HOSTNAME)) and isinstance(
hostname_bytes, bytes
):
with contextlib.suppress(AttributeError, UnicodeDecodeError):
hostname = hostname_bytes.decode()
mac_address = _format_mac(cast(str, packet[Ether].src))
if ip_address is not None and mac_address is not None: if ip_address is not None and mac_address is not None:
self.process_client(ip_address, hostname, mac_address) self.process_client(ip_address, hostname, mac_address)
@ -470,29 +510,19 @@ class DHCPWatcher(WatcherBase):
self._sniffer.thread.name = self.__class__.__name__ self._sniffer.thread.name = self.__class__.__name__
def _decode_dhcp_option(dhcp_options, key): def _dhcp_options_as_dict(
"""Extract and decode data from a packet option.""" dhcp_options: Iterable[tuple[str, int | bytes | None]]
for option in dhcp_options: ) -> dict[str, str | int | bytes | None]:
if len(option) < 2 or option[0] != key: """Extract data from packet options as a dict."""
continue return {option[0]: option[1] for option in dhcp_options if len(option) >= 2}
value = option[1]
if value is None or key != HOSTNAME:
return value
# hostname is unicode
try:
return value.decode()
except (AttributeError, UnicodeDecodeError):
return None
def _format_mac(mac_address): def _format_mac(mac_address: str) -> str:
"""Format a mac address for matching.""" """Format a mac address for matching."""
return format_mac(mac_address).replace(":", "") return format_mac(mac_address).replace(":", "")
def _verify_l2socket_setup(cap_filter): def _verify_l2socket_setup(cap_filter: str) -> None:
"""Create a socket using the scapy configured l2socket. """Create a socket using the scapy configured l2socket.
Try to create the socket Try to create the socket
@ -504,7 +534,7 @@ def _verify_l2socket_setup(cap_filter):
conf.L2socket(filter=cap_filter) conf.L2socket(filter=cap_filter)
def _verify_working_pcap(cap_filter): def _verify_working_pcap(cap_filter: str) -> None:
"""Verify we can create a packet filter. """Verify we can create a packet filter.
If we cannot create a filter we will be listening for If we cannot create a filter we will be listening for

View File

@ -60,6 +60,24 @@ MAX_LOAD_CONCURRENTLY = 4
MOVED_ZEROCONF_PROPS = ("macaddress", "model", "manufacturer") MOVED_ZEROCONF_PROPS = ("macaddress", "model", "manufacturer")
class DHCPMatcherRequired(TypedDict, total=True):
"""Matcher for the dhcp integration for required fields."""
domain: str
class DHCPMatcherOptional(TypedDict, total=False):
"""Matcher for the dhcp integration for optional fields."""
macaddress: str
hostname: str
registered_devices: bool
class DHCPMatcher(DHCPMatcherRequired, DHCPMatcherOptional):
"""Matcher for the dhcp integration."""
class Manifest(TypedDict, total=False): class Manifest(TypedDict, total=False):
""" """
Integration manifest. Integration manifest.
@ -228,16 +246,16 @@ async def async_get_zeroconf(
return zeroconf return zeroconf
async def async_get_dhcp(hass: HomeAssistant) -> list[dict[str, str | bool]]: async def async_get_dhcp(hass: HomeAssistant) -> list[DHCPMatcher]:
"""Return cached list of dhcp types.""" """Return cached list of dhcp types."""
dhcp: list[dict[str, str | bool]] = DHCP.copy() dhcp = cast(list[DHCPMatcher], DHCP.copy())
integrations = await async_get_custom_components(hass) integrations = await async_get_custom_components(hass)
for integration in integrations.values(): for integration in integrations.values():
if not integration.dhcp: if not integration.dhcp:
continue continue
for entry in integration.dhcp: for entry in integration.dhcp:
dhcp.append({"domain": integration.domain, **entry}) dhcp.append(cast(DHCPMatcher, {"domain": integration.domain, **entry}))
return dhcp return dhcp

View File

@ -549,6 +549,17 @@ no_implicit_optional = true
warn_return_any = true warn_return_any = true
warn_unreachable = true warn_unreachable = true
[mypy-homeassistant.components.dhcp.*]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.dlna_dmr.*] [mypy-homeassistant.components.dlna_dmr.*]
check_untyped_defs = true check_untyped_defs = true
disallow_incomplete_defs = true disallow_incomplete_defs = true