From 3356ee5dedc11726fe093e27f7b1ec53e97431c4 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 16 Apr 2025 08:46:37 -1000 Subject: [PATCH] Add DHCP discovery subscribe websocket API --- homeassistant/components/dhcp/__init__.py | 119 ++++++++++++----- .../components/dhcp/websocket_api.py | 123 ++++++++++++++++++ 2 files changed, 206 insertions(+), 36 deletions(-) create mode 100644 homeassistant/components/dhcp/websocket_api.py diff --git a/homeassistant/components/dhcp/__init__.py b/homeassistant/components/dhcp/__init__.py index a11a0b262b0..d2a248fa516 100644 --- a/homeassistant/components/dhcp/__init__.py +++ b/homeassistant/components/dhcp/__init__.py @@ -4,6 +4,7 @@ from __future__ import annotations import asyncio from collections.abc import Callable +import dataclasses from dataclasses import dataclass from datetime import timedelta from fnmatch import translate @@ -11,7 +12,7 @@ from functools import lru_cache, partial import itertools import logging import re -from typing import Any, Final +from typing import TYPE_CHECKING, Any, Final, TypedDict import aiodhcpwatcher from aiodiscover import DiscoverHosts @@ -38,6 +39,7 @@ from homeassistant.const import ( STATE_HOME, ) from homeassistant.core import ( + CALLBACK_TYPE, Event, EventStateChangedData, HomeAssistant, @@ -65,7 +67,9 @@ from homeassistant.helpers.event import ( from homeassistant.helpers.service_info.dhcp import DhcpServiceInfo as _DhcpServiceInfo from homeassistant.helpers.typing import ConfigType from homeassistant.loader import DHCPMatcher, async_get_dhcp +from homeassistant.util.hass_dict import HassKey +from . import websocket_api from .const import DOMAIN CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN) @@ -96,6 +100,27 @@ class DhcpMatchers: oui_matchers: dict[str, list[DHCPMatcher]] +class DHCPAddressData(TypedDict): + """Typed dict for DHCP address data.""" + + hostname: str + ip: str + + +@dataclasses.dataclass(slots=True) +class DHCPData: + """Data for the dhcp component.""" + + integration_matchers: DhcpMatchers + callbacks: set[Callable[[_DhcpServiceInfo], None]] = dataclasses.field( + default_factory=set + ) + address_data: dict[str, DHCPAddressData] = dataclasses.field(default_factory=dict) + + +DATA_DHCP: HassKey[DHCPData] = HassKey(DOMAIN) + + def async_index_integration_matchers( integration_matchers: list[DHCPMatcher], ) -> DhcpMatchers: @@ -131,38 +156,59 @@ def async_index_integration_matchers( ) +@callback +def async_register_dhcp_callback_internal( + hass: HomeAssistant, + callback_: Callable[[_DhcpServiceInfo], None], +) -> CALLBACK_TYPE: + """Register a dhcp callback. + + For internal use only. + This is not intended for use by integrations. + """ + callbacks = hass.data[DATA_DHCP].callbacks + callbacks.add(callback_) + return partial(callbacks.remove, callback_) + + +@callback +def async_get_address_data_internal( + hass: HomeAssistant, +) -> dict[str, DHCPAddressData]: + """Get the address data.""" + return hass.data[DATA_DHCP].address_data + + async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up the dhcp component.""" - watchers: list[WatcherBase] = [] - address_data: dict[str, dict[str, str]] = {} integration_matchers = async_index_integration_matchers(await async_get_dhcp(hass)) + dhcp_data = DHCPData(integration_matchers=integration_matchers) + hass.data[DATA_DHCP] = dhcp_data + websocket_api.async_setup(hass) + watchers: list[WatcherBase] = [] # For the passive classes we need to start listening # for state changes and connect the dispatchers before # everything else starts up or we will miss events - device_watcher = DeviceTrackerWatcher(hass, address_data, integration_matchers) + device_watcher = DeviceTrackerWatcher(hass, dhcp_data) device_watcher.async_start() watchers.append(device_watcher) - device_tracker_registered_watcher = DeviceTrackerRegisteredWatcher( - hass, address_data, integration_matchers - ) + device_tracker_registered_watcher = DeviceTrackerRegisteredWatcher(hass, dhcp_data) device_tracker_registered_watcher.async_start() watchers.append(device_tracker_registered_watcher) async def _async_initialize(event: Event) -> None: await aiodhcpwatcher.async_init() - network_watcher = NetworkWatcher(hass, address_data, integration_matchers) + network_watcher = NetworkWatcher(hass, dhcp_data) network_watcher.async_start() watchers.append(network_watcher) - dhcp_watcher = DHCPWatcher(hass, address_data, integration_matchers) + dhcp_watcher = DHCPWatcher(hass, dhcp_data) await dhcp_watcher.async_start() watchers.append(dhcp_watcher) - rediscovery_watcher = RediscoveryWatcher( - hass, address_data, integration_matchers - ) + rediscovery_watcher = RediscoveryWatcher(hass, dhcp_data) rediscovery_watcher.async_start() watchers.append(rediscovery_watcher) @@ -180,18 +226,13 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: class WatcherBase: """Base class for dhcp and device tracker watching.""" - def __init__( - self, - hass: HomeAssistant, - address_data: dict[str, dict[str, str]], - integration_matchers: DhcpMatchers, - ) -> None: + def __init__(self, hass: HomeAssistant, dhcp_data: DHCPData) -> None: """Initialize class.""" super().__init__() - self.hass = hass - self._integration_matchers = integration_matchers - self._address_data = address_data + self._callbacks = dhcp_data.callbacks + self._integration_matchers = dhcp_data.integration_matchers + self._address_data = dhcp_data.address_data self._unsub: Callable[[], None] | None = None @callback @@ -230,18 +271,18 @@ class WatcherBase: mac_address = formatted_mac.replace(":", "") compressed_ip_address = made_ip_address.compressed - data = self._address_data.get(mac_address) + current_data = self._address_data.get(mac_address) if ( not force - and data - and data[IP_ADDRESS] == compressed_ip_address - and data[HOSTNAME].startswith(hostname) + and current_data + and current_data[IP_ADDRESS] == compressed_ip_address + and current_data[HOSTNAME].startswith(hostname) ): # If the address data is the same no need # to process it return - data = {IP_ADDRESS: compressed_ip_address, HOSTNAME: hostname} + data: DHCPAddressData = {IP_ADDRESS: compressed_ip_address, HOSTNAME: hostname} self._address_data[mac_address] = data lowercase_hostname = hostname.lower() @@ -287,8 +328,19 @@ class WatcherBase: _LOGGER.debug("Matched %s against %s", data, matcher) matched_domains.add(domain) - if not matched_domains: - return # avoid creating DiscoveryKey if there are no matches + service_info: _DhcpServiceInfo | None = None + if self._callbacks or matched_domains: + service_info = _DhcpServiceInfo( + ip=ip_address, + hostname=lowercase_hostname, + macaddress=mac_address, + ) + + if TYPE_CHECKING: + assert service_info is not None + + for callback_ in self._callbacks: + callback_(service_info) discovery_key = DiscoveryKey( domain=DOMAIN, @@ -300,11 +352,7 @@ class WatcherBase: self.hass, domain, {"source": config_entries.SOURCE_DHCP}, - _DhcpServiceInfo( - ip=ip_address, - hostname=lowercase_hostname, - macaddress=mac_address, - ), + service_info, discovery_key=discovery_key, ) @@ -315,11 +363,10 @@ class NetworkWatcher(WatcherBase): def __init__( self, hass: HomeAssistant, - address_data: dict[str, dict[str, str]], - integration_matchers: DhcpMatchers, + dhcp_data: DHCPData, ) -> None: """Initialize class.""" - super().__init__(hass, address_data, integration_matchers) + super().__init__(hass, dhcp_data) self._discover_hosts: DiscoverHosts | None = None self._discover_task: asyncio.Task | None = None diff --git a/homeassistant/components/dhcp/websocket_api.py b/homeassistant/components/dhcp/websocket_api.py new file mode 100644 index 00000000000..82ccf6c162b --- /dev/null +++ b/homeassistant/components/dhcp/websocket_api.py @@ -0,0 +1,123 @@ +"""The dhcp integration websocket apis.""" + +from __future__ import annotations + +from typing import Any, TypedDict + +import voluptuous as vol + +from homeassistant.components import websocket_api +from homeassistant.core import HomeAssistant, callback +from homeassistant.helpers.json import json_bytes + +from . import ( + HOSTNAME, + IP_ADDRESS, + _DhcpServiceInfo, + async_get_address_data_internal, + async_register_dhcp_callback_internal, +) + + +class DHCPDiscovery(TypedDict): + """Typed dict for DHCP discovery.""" + + mac_address: str + hostname: str + ip_address: str + + +@callback +def async_setup(hass: HomeAssistant) -> None: + """Set up the DHCP websocket API.""" + websocket_api.async_register_command(hass, ws_subscribe_discovery) + + +def serialize_service_info(service_info: _DhcpServiceInfo) -> DHCPDiscovery: + """Serialize a _DhcpServiceInfo object.""" + serialized: DHCPDiscovery = { + "mac_address": service_info.macaddress, + "hostname": service_info.hostname, + "ip_address": service_info.ip, + } + return serialized + + +class _DiscoverySubscription: + """Class to hold and manage the subscription data.""" + + def __init__( + self, + hass: HomeAssistant, + connection: websocket_api.ActiveConnection, + ws_msg_id: int, + ) -> None: + """Initialize the subscription data.""" + self.hass = hass + self.ws_msg_id = ws_msg_id + self.connection = connection + + @callback + def async_start(self) -> None: + """Start the subscription.""" + connection = self.connection + connection.subscriptions[self.ws_msg_id] = ( + async_register_dhcp_callback_internal( + self.hass, + self._async_on_discovery, + ) + ) + connection.send_message( + json_bytes(websocket_api.result_message(self.ws_msg_id)) + ) + self._async_send_current_address_data() + + def _async_event_message(self, message: dict[str, Any]) -> None: + self.connection.send_message( + json_bytes(websocket_api.event_message(self.ws_msg_id, message)) + ) + + def _async_send_current_address_data(self) -> None: + """Send the current address data.""" + address_data = async_get_address_data_internal(self.hass) + self._async_event_message( + { + "add": [ + { + "address": mac_address, + "hostname": data[HOSTNAME], + "ip_address": data[IP_ADDRESS], + } + for mac_address, data in address_data.items() + ] + } + ) + + def _async_added(self, service_infos: list[_DhcpServiceInfo]) -> None: + self._async_event_message( + { + "add": [ + serialize_service_info(service_info) + for service_info in service_infos + ] + } + ) + + @callback + def _async_on_discovery(self, service_info: _DhcpServiceInfo) -> None: + """Handle the callback.""" + self._async_event_message({"add": [serialize_service_info(service_info)]}) + + +@websocket_api.require_admin +@websocket_api.websocket_command( + { + vol.Required("type"): "dhcp/subscribe_discovery", + } +) +@websocket_api.async_response +async def ws_subscribe_discovery( + hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any] +) -> None: + """Handle subscribe discovery websocket command.""" + _DiscoverySubscription(hass, connection, msg["id"]).async_start()