diff --git a/.strict-typing b/.strict-typing index f6908fc39a6..b1f6db6a532 100644 --- a/.strict-typing +++ b/.strict-typing @@ -68,6 +68,7 @@ homeassistant.components.device_automation.* homeassistant.components.device_tracker.* homeassistant.components.devolo_home_control.* homeassistant.components.devolo_home_network.* +homeassistant.components.dhcp.* homeassistant.components.dlna_dmr.* homeassistant.components.dnsip.* homeassistant.components.dsmr.* diff --git a/homeassistant/components/dhcp/__init__.py b/homeassistant/components/dhcp/__init__.py index 0b5f8a49a34..1756f620f46 100644 --- a/homeassistant/components/dhcp/__init__.py +++ b/homeassistant/components/dhcp/__init__.py @@ -2,6 +2,9 @@ from __future__ import annotations from abc import abstractmethod +import asyncio +from collections.abc import Callable, Iterable +import contextlib from dataclasses import dataclass from datetime import timedelta import fnmatch @@ -9,7 +12,7 @@ from ipaddress import ip_address as make_ip_address import logging import os import threading -from typing import Any, Final +from typing import TYPE_CHECKING, Any, Final, cast from aiodiscover import DiscoverHosts from aiodiscover.discovery import ( @@ -51,12 +54,16 @@ from homeassistant.helpers.event import ( ) from homeassistant.helpers.frame import report from homeassistant.helpers.typing import ConfigType -from homeassistant.loader import async_get_dhcp +from homeassistant.loader import DHCPMatcher, async_get_dhcp from homeassistant.util.async_ import run_callback_threadsafe from homeassistant.util.network import is_invalid, is_link_local, is_loopback from .const import DOMAIN +if TYPE_CHECKING: + from scapy.packet import Packet + from scapy.sendrecv import AsyncSniffer + FILTER = "udp and (port 67 or 68)" REQUESTED_ADDR = "requested_addr" MESSAGE_TYPE = "message-type" @@ -115,7 +122,6 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: watchers: list[WatcherBase] = [] address_data: dict[str, dict[str, str]] = {} integration_matchers = await async_get_dhcp(hass) - # For the passive classes we need to start listening # for state changes and connect the dispatchers before # everything else starts up or we will miss events @@ -124,13 +130,13 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: await passive_watcher.async_start() watchers.append(passive_watcher) - async def _initialize(_): + async def _initialize(event: Event) -> None: for active_cls in (DHCPWatcher, NetworkWatcher): active_watcher = active_cls(hass, address_data, integration_matchers) await active_watcher.async_start() watchers.append(active_watcher) - async def _async_stop(*_): + async def _async_stop(event: Event) -> None: for watcher in watchers: await watcher.async_stop() @@ -143,7 +149,12 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: class WatcherBase: """Base class for dhcp and device tracker watching.""" - def __init__(self, hass, address_data, integration_matchers): + def __init__( + self, + hass: HomeAssistant, + address_data: dict[str, dict[str, str]], + integration_matchers: list[DHCPMatcher], + ) -> None: """Initialize class.""" super().__init__() @@ -152,11 +163,11 @@ class WatcherBase: self._address_data = address_data @abstractmethod - async def async_stop(self): + async def async_stop(self) -> None: """Stop the watcher.""" @abstractmethod - async def async_start(self): + async def async_start(self) -> None: """Start the watcher.""" def process_client(self, ip_address: str, hostname: str, mac_address: str) -> None: @@ -197,8 +208,8 @@ class WatcherBase: data = {MAC_ADDRESS: mac_address, HOSTNAME: hostname} self._address_data[ip_address] = data - lowercase_hostname = data[HOSTNAME].lower() - uppercase_mac = data[MAC_ADDRESS].upper() + lowercase_hostname = hostname.lower() + uppercase_mac = mac_address.upper() _LOGGER.debug( "Processing updated address data for %s: mac=%s hostname=%s", @@ -218,22 +229,24 @@ class WatcherBase: if entry := self.hass.config_entries.async_get_entry(entry_id): device_domains.add(entry.domain) - for entry in self._integration_matchers: - if entry.get(REGISTERED_DEVICES) and not entry["domain"] in device_domains: + for matcher in self._integration_matchers: + domain = matcher["domain"] + + if matcher.get(REGISTERED_DEVICES) and domain not in device_domains: continue - if MAC_ADDRESS in entry and not fnmatch.fnmatch( - uppercase_mac, entry[MAC_ADDRESS] - ): + if ( + matcher_mac := matcher.get(MAC_ADDRESS) + ) is not None and not fnmatch.fnmatch(uppercase_mac, matcher_mac): continue - if HOSTNAME in entry and not fnmatch.fnmatch( - lowercase_hostname, entry[HOSTNAME] - ): + if ( + matcher_hostname := matcher.get(HOSTNAME) + ) is not None and not fnmatch.fnmatch(lowercase_hostname, matcher_hostname): continue - _LOGGER.debug("Matched %s against %s", data, entry) - matched_domains.add(entry["domain"]) + _LOGGER.debug("Matched %s against %s", data, matcher) + matched_domains.add(domain) for domain in matched_domains: discovery_flow.async_create_flow( @@ -243,7 +256,7 @@ class WatcherBase: DhcpServiceInfo( ip=ip_address, hostname=lowercase_hostname, - macaddress=data[MAC_ADDRESS], + macaddress=mac_address, ), ) @@ -251,14 +264,19 @@ class WatcherBase: class NetworkWatcher(WatcherBase): """Class to query ptr records routers.""" - def __init__(self, hass, address_data, integration_matchers): + def __init__( + self, + hass: HomeAssistant, + address_data: dict[str, dict[str, str]], + integration_matchers: list[DHCPMatcher], + ) -> None: """Initialize class.""" super().__init__(hass, address_data, integration_matchers) - self._unsub = None - self._discover_hosts = None - self._discover_task = None + self._unsub: Callable[[], None] | None = None + self._discover_hosts: DiscoverHosts | None = None + self._discover_task: asyncio.Task | None = None - async def async_stop(self): + async def async_stop(self) -> None: """Stop scanning for new devices on the network.""" if self._unsub: self._unsub() @@ -267,7 +285,7 @@ class NetworkWatcher(WatcherBase): self._discover_task.cancel() self._discover_task = None - async def async_start(self): + async def async_start(self) -> None: """Start scanning for new devices on the network.""" self._discover_hosts = DiscoverHosts() self._unsub = async_track_time_interval( @@ -276,14 +294,15 @@ class NetworkWatcher(WatcherBase): self.async_start_discover() @callback - def async_start_discover(self, *_): + def async_start_discover(self, *_: Any) -> None: """Start a new discovery task if one is not running.""" if self._discover_task and not self._discover_task.done(): return self._discover_task = self.hass.async_create_task(self.async_discover()) - async def async_discover(self): + async def async_discover(self) -> None: """Process discovery.""" + assert self._discover_hosts is not None for host in await self._discover_hosts.async_discover(): self.async_process_client( host[DISCOVERY_IP_ADDRESS], @@ -295,18 +314,23 @@ class NetworkWatcher(WatcherBase): class DeviceTrackerWatcher(WatcherBase): """Class to watch dhcp data from routers.""" - def __init__(self, hass, address_data, integration_matchers): + def __init__( + self, + hass: HomeAssistant, + address_data: dict[str, dict[str, str]], + integration_matchers: list[DHCPMatcher], + ) -> None: """Initialize class.""" super().__init__(hass, address_data, integration_matchers) - self._unsub = None + self._unsub: Callable[[], None] | None = None - async def async_stop(self): + async def async_stop(self) -> None: """Stop watching for new device trackers.""" if self._unsub: self._unsub() self._unsub = None - async def async_start(self): + async def async_start(self) -> None: """Stop watching for new device trackers.""" self._unsub = async_track_state_added_domain( self.hass, [DEVICE_TRACKER_DOMAIN], self._async_process_device_event @@ -315,12 +339,12 @@ class DeviceTrackerWatcher(WatcherBase): self._async_process_device_state(state) @callback - def _async_process_device_event(self, event: Event): + def _async_process_device_event(self, event: Event) -> None: """Process a device tracker state change event.""" self._async_process_device_state(event.data["new_state"]) @callback - def _async_process_device_state(self, state: State): + def _async_process_device_state(self, state: State) -> None: """Process a device tracker state.""" if state.state != STATE_HOME: return @@ -343,18 +367,23 @@ class DeviceTrackerWatcher(WatcherBase): class DeviceTrackerRegisteredWatcher(WatcherBase): """Class to watch data from device tracker registrations.""" - def __init__(self, hass, address_data, integration_matchers): + def __init__( + self, + hass: HomeAssistant, + address_data: dict[str, dict[str, str]], + integration_matchers: list[DHCPMatcher], + ) -> None: """Initialize class.""" super().__init__(hass, address_data, integration_matchers) - self._unsub = None + self._unsub: Callable[[], None] | None = None - async def async_stop(self): + async def async_stop(self) -> None: """Stop watching for device tracker registrations.""" if self._unsub: self._unsub() self._unsub = None - async def async_start(self): + async def async_start(self) -> None: """Stop watching for device tracker registrations.""" self._unsub = async_dispatcher_connect( self.hass, CONNECTED_DEVICE_REGISTERED, self._async_process_device_data @@ -376,26 +405,32 @@ class DeviceTrackerRegisteredWatcher(WatcherBase): class DHCPWatcher(WatcherBase): """Class to watch dhcp requests.""" - def __init__(self, hass, address_data, integration_matchers): + def __init__( + self, + hass: HomeAssistant, + address_data: dict[str, dict[str, str]], + integration_matchers: list[DHCPMatcher], + ) -> None: """Initialize class.""" super().__init__(hass, address_data, integration_matchers) - self._sniffer = None + self._sniffer: AsyncSniffer | None = None self._started = threading.Event() - async def async_stop(self): + async def async_stop(self) -> None: """Stop watching for new device trackers.""" await self.hass.async_add_executor_job(self._stop) - def _stop(self): + def _stop(self) -> None: """Stop the thread.""" if self._started.is_set(): + assert self._sniffer is not None self._sniffer.stop() - async def async_start(self): + async def async_start(self) -> None: """Start watching for dhcp packets.""" await self.hass.async_add_executor_job(self._start) - def _start(self): + def _start(self) -> None: """Start watching for dhcp packets.""" # Local import because importing from scapy has side effects such as opening # sockets @@ -417,20 +452,25 @@ class DHCPWatcher(WatcherBase): AsyncSniffer, ) - def _handle_dhcp_packet(packet): + def _handle_dhcp_packet(packet: Packet) -> None: """Process a dhcp packet.""" if DHCP not in packet: return - options = packet[DHCP].options - request_type = _decode_dhcp_option(options, MESSAGE_TYPE) - if request_type != DHCP_REQUEST: + options_dict = _dhcp_options_as_dict(packet[DHCP].options) + if options_dict.get(MESSAGE_TYPE) != DHCP_REQUEST: # Not a DHCP request return - ip_address = _decode_dhcp_option(options, REQUESTED_ADDR) or packet[IP].src - hostname = _decode_dhcp_option(options, HOSTNAME) or "" - mac_address = _format_mac(packet[Ether].src) + ip_address = options_dict.get(REQUESTED_ADDR) or cast(str, packet[IP].src) + assert isinstance(ip_address, str) + hostname = "" + if (hostname_bytes := options_dict.get(HOSTNAME)) and isinstance( + hostname_bytes, bytes + ): + with contextlib.suppress(AttributeError, UnicodeDecodeError): + hostname = hostname_bytes.decode() + mac_address = _format_mac(cast(str, packet[Ether].src)) if ip_address is not None and mac_address is not None: self.process_client(ip_address, hostname, mac_address) @@ -470,29 +510,19 @@ class DHCPWatcher(WatcherBase): self._sniffer.thread.name = self.__class__.__name__ -def _decode_dhcp_option(dhcp_options, key): - """Extract and decode data from a packet option.""" - for option in dhcp_options: - if len(option) < 2 or option[0] != key: - continue - - value = option[1] - if value is None or key != HOSTNAME: - return value - - # hostname is unicode - try: - return value.decode() - except (AttributeError, UnicodeDecodeError): - return None +def _dhcp_options_as_dict( + dhcp_options: Iterable[tuple[str, int | bytes | None]] +) -> dict[str, str | int | bytes | None]: + """Extract data from packet options as a dict.""" + return {option[0]: option[1] for option in dhcp_options if len(option) >= 2} -def _format_mac(mac_address): +def _format_mac(mac_address: str) -> str: """Format a mac address for matching.""" return format_mac(mac_address).replace(":", "") -def _verify_l2socket_setup(cap_filter): +def _verify_l2socket_setup(cap_filter: str) -> None: """Create a socket using the scapy configured l2socket. Try to create the socket @@ -504,7 +534,7 @@ def _verify_l2socket_setup(cap_filter): conf.L2socket(filter=cap_filter) -def _verify_working_pcap(cap_filter): +def _verify_working_pcap(cap_filter: str) -> None: """Verify we can create a packet filter. If we cannot create a filter we will be listening for diff --git a/homeassistant/loader.py b/homeassistant/loader.py index f5c68897e2e..8e4521eddba 100644 --- a/homeassistant/loader.py +++ b/homeassistant/loader.py @@ -60,6 +60,24 @@ MAX_LOAD_CONCURRENTLY = 4 MOVED_ZEROCONF_PROPS = ("macaddress", "model", "manufacturer") +class DHCPMatcherRequired(TypedDict, total=True): + """Matcher for the dhcp integration for required fields.""" + + domain: str + + +class DHCPMatcherOptional(TypedDict, total=False): + """Matcher for the dhcp integration for optional fields.""" + + macaddress: str + hostname: str + registered_devices: bool + + +class DHCPMatcher(DHCPMatcherRequired, DHCPMatcherOptional): + """Matcher for the dhcp integration.""" + + class Manifest(TypedDict, total=False): """ Integration manifest. @@ -228,16 +246,16 @@ async def async_get_zeroconf( return zeroconf -async def async_get_dhcp(hass: HomeAssistant) -> list[dict[str, str | bool]]: +async def async_get_dhcp(hass: HomeAssistant) -> list[DHCPMatcher]: """Return cached list of dhcp types.""" - dhcp: list[dict[str, str | bool]] = DHCP.copy() + dhcp = cast(list[DHCPMatcher], DHCP.copy()) integrations = await async_get_custom_components(hass) for integration in integrations.values(): if not integration.dhcp: continue for entry in integration.dhcp: - dhcp.append({"domain": integration.domain, **entry}) + dhcp.append(cast(DHCPMatcher, {"domain": integration.domain, **entry})) return dhcp diff --git a/mypy.ini b/mypy.ini index 55d608c8628..63c1419d175 100644 --- a/mypy.ini +++ b/mypy.ini @@ -549,6 +549,17 @@ no_implicit_optional = true warn_return_any = true warn_unreachable = true +[mypy-homeassistant.components.dhcp.*] +check_untyped_defs = true +disallow_incomplete_defs = true +disallow_subclassing_any = true +disallow_untyped_calls = true +disallow_untyped_decorators = true +disallow_untyped_defs = true +no_implicit_optional = true +warn_return_any = true +warn_unreachable = true + [mypy-homeassistant.components.dlna_dmr.*] check_untyped_defs = true disallow_incomplete_defs = true