mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 13:17:32 +00:00
Add DHCP discovery subscribe websocket API
This commit is contained in:
parent
9d02436a72
commit
3356ee5ded
@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
import dataclasses
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
from fnmatch import translate
|
||||
@ -11,7 +12,7 @@ from functools import lru_cache, partial
|
||||
import itertools
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Final
|
||||
from typing import TYPE_CHECKING, Any, Final, TypedDict
|
||||
|
||||
import aiodhcpwatcher
|
||||
from aiodiscover import DiscoverHosts
|
||||
@ -38,6 +39,7 @@ from homeassistant.const import (
|
||||
STATE_HOME,
|
||||
)
|
||||
from homeassistant.core import (
|
||||
CALLBACK_TYPE,
|
||||
Event,
|
||||
EventStateChangedData,
|
||||
HomeAssistant,
|
||||
@ -65,7 +67,9 @@ from homeassistant.helpers.event import (
|
||||
from homeassistant.helpers.service_info.dhcp import DhcpServiceInfo as _DhcpServiceInfo
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
from homeassistant.loader import DHCPMatcher, async_get_dhcp
|
||||
from homeassistant.util.hass_dict import HassKey
|
||||
|
||||
from . import websocket_api
|
||||
from .const import DOMAIN
|
||||
|
||||
CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN)
|
||||
@ -96,6 +100,27 @@ class DhcpMatchers:
|
||||
oui_matchers: dict[str, list[DHCPMatcher]]
|
||||
|
||||
|
||||
class DHCPAddressData(TypedDict):
|
||||
"""Typed dict for DHCP address data."""
|
||||
|
||||
hostname: str
|
||||
ip: str
|
||||
|
||||
|
||||
@dataclasses.dataclass(slots=True)
|
||||
class DHCPData:
|
||||
"""Data for the dhcp component."""
|
||||
|
||||
integration_matchers: DhcpMatchers
|
||||
callbacks: set[Callable[[_DhcpServiceInfo], None]] = dataclasses.field(
|
||||
default_factory=set
|
||||
)
|
||||
address_data: dict[str, DHCPAddressData] = dataclasses.field(default_factory=dict)
|
||||
|
||||
|
||||
DATA_DHCP: HassKey[DHCPData] = HassKey(DOMAIN)
|
||||
|
||||
|
||||
def async_index_integration_matchers(
|
||||
integration_matchers: list[DHCPMatcher],
|
||||
) -> DhcpMatchers:
|
||||
@ -131,38 +156,59 @@ def async_index_integration_matchers(
|
||||
)
|
||||
|
||||
|
||||
@callback
|
||||
def async_register_dhcp_callback_internal(
|
||||
hass: HomeAssistant,
|
||||
callback_: Callable[[_DhcpServiceInfo], None],
|
||||
) -> CALLBACK_TYPE:
|
||||
"""Register a dhcp callback.
|
||||
|
||||
For internal use only.
|
||||
This is not intended for use by integrations.
|
||||
"""
|
||||
callbacks = hass.data[DATA_DHCP].callbacks
|
||||
callbacks.add(callback_)
|
||||
return partial(callbacks.remove, callback_)
|
||||
|
||||
|
||||
@callback
|
||||
def async_get_address_data_internal(
|
||||
hass: HomeAssistant,
|
||||
) -> dict[str, DHCPAddressData]:
|
||||
"""Get the address data."""
|
||||
return hass.data[DATA_DHCP].address_data
|
||||
|
||||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Set up the dhcp component."""
|
||||
watchers: list[WatcherBase] = []
|
||||
address_data: dict[str, dict[str, str]] = {}
|
||||
integration_matchers = async_index_integration_matchers(await async_get_dhcp(hass))
|
||||
dhcp_data = DHCPData(integration_matchers=integration_matchers)
|
||||
hass.data[DATA_DHCP] = dhcp_data
|
||||
websocket_api.async_setup(hass)
|
||||
watchers: list[WatcherBase] = []
|
||||
# For the passive classes we need to start listening
|
||||
# for state changes and connect the dispatchers before
|
||||
# everything else starts up or we will miss events
|
||||
device_watcher = DeviceTrackerWatcher(hass, address_data, integration_matchers)
|
||||
device_watcher = DeviceTrackerWatcher(hass, dhcp_data)
|
||||
device_watcher.async_start()
|
||||
watchers.append(device_watcher)
|
||||
|
||||
device_tracker_registered_watcher = DeviceTrackerRegisteredWatcher(
|
||||
hass, address_data, integration_matchers
|
||||
)
|
||||
device_tracker_registered_watcher = DeviceTrackerRegisteredWatcher(hass, dhcp_data)
|
||||
device_tracker_registered_watcher.async_start()
|
||||
watchers.append(device_tracker_registered_watcher)
|
||||
|
||||
async def _async_initialize(event: Event) -> None:
|
||||
await aiodhcpwatcher.async_init()
|
||||
|
||||
network_watcher = NetworkWatcher(hass, address_data, integration_matchers)
|
||||
network_watcher = NetworkWatcher(hass, dhcp_data)
|
||||
network_watcher.async_start()
|
||||
watchers.append(network_watcher)
|
||||
|
||||
dhcp_watcher = DHCPWatcher(hass, address_data, integration_matchers)
|
||||
dhcp_watcher = DHCPWatcher(hass, dhcp_data)
|
||||
await dhcp_watcher.async_start()
|
||||
watchers.append(dhcp_watcher)
|
||||
|
||||
rediscovery_watcher = RediscoveryWatcher(
|
||||
hass, address_data, integration_matchers
|
||||
)
|
||||
rediscovery_watcher = RediscoveryWatcher(hass, dhcp_data)
|
||||
rediscovery_watcher.async_start()
|
||||
watchers.append(rediscovery_watcher)
|
||||
|
||||
@ -180,18 +226,13 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
class WatcherBase:
|
||||
"""Base class for dhcp and device tracker watching."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
address_data: dict[str, dict[str, str]],
|
||||
integration_matchers: DhcpMatchers,
|
||||
) -> None:
|
||||
def __init__(self, hass: HomeAssistant, dhcp_data: DHCPData) -> None:
|
||||
"""Initialize class."""
|
||||
super().__init__()
|
||||
|
||||
self.hass = hass
|
||||
self._integration_matchers = integration_matchers
|
||||
self._address_data = address_data
|
||||
self._callbacks = dhcp_data.callbacks
|
||||
self._integration_matchers = dhcp_data.integration_matchers
|
||||
self._address_data = dhcp_data.address_data
|
||||
self._unsub: Callable[[], None] | None = None
|
||||
|
||||
@callback
|
||||
@ -230,18 +271,18 @@ class WatcherBase:
|
||||
mac_address = formatted_mac.replace(":", "")
|
||||
compressed_ip_address = made_ip_address.compressed
|
||||
|
||||
data = self._address_data.get(mac_address)
|
||||
current_data = self._address_data.get(mac_address)
|
||||
if (
|
||||
not force
|
||||
and data
|
||||
and data[IP_ADDRESS] == compressed_ip_address
|
||||
and data[HOSTNAME].startswith(hostname)
|
||||
and current_data
|
||||
and current_data[IP_ADDRESS] == compressed_ip_address
|
||||
and current_data[HOSTNAME].startswith(hostname)
|
||||
):
|
||||
# If the address data is the same no need
|
||||
# to process it
|
||||
return
|
||||
|
||||
data = {IP_ADDRESS: compressed_ip_address, HOSTNAME: hostname}
|
||||
data: DHCPAddressData = {IP_ADDRESS: compressed_ip_address, HOSTNAME: hostname}
|
||||
self._address_data[mac_address] = data
|
||||
|
||||
lowercase_hostname = hostname.lower()
|
||||
@ -287,8 +328,19 @@ class WatcherBase:
|
||||
_LOGGER.debug("Matched %s against %s", data, matcher)
|
||||
matched_domains.add(domain)
|
||||
|
||||
if not matched_domains:
|
||||
return # avoid creating DiscoveryKey if there are no matches
|
||||
service_info: _DhcpServiceInfo | None = None
|
||||
if self._callbacks or matched_domains:
|
||||
service_info = _DhcpServiceInfo(
|
||||
ip=ip_address,
|
||||
hostname=lowercase_hostname,
|
||||
macaddress=mac_address,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
assert service_info is not None
|
||||
|
||||
for callback_ in self._callbacks:
|
||||
callback_(service_info)
|
||||
|
||||
discovery_key = DiscoveryKey(
|
||||
domain=DOMAIN,
|
||||
@ -300,11 +352,7 @@ class WatcherBase:
|
||||
self.hass,
|
||||
domain,
|
||||
{"source": config_entries.SOURCE_DHCP},
|
||||
_DhcpServiceInfo(
|
||||
ip=ip_address,
|
||||
hostname=lowercase_hostname,
|
||||
macaddress=mac_address,
|
||||
),
|
||||
service_info,
|
||||
discovery_key=discovery_key,
|
||||
)
|
||||
|
||||
@ -315,11 +363,10 @@ class NetworkWatcher(WatcherBase):
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
address_data: dict[str, dict[str, str]],
|
||||
integration_matchers: DhcpMatchers,
|
||||
dhcp_data: DHCPData,
|
||||
) -> None:
|
||||
"""Initialize class."""
|
||||
super().__init__(hass, address_data, integration_matchers)
|
||||
super().__init__(hass, dhcp_data)
|
||||
self._discover_hosts: DiscoverHosts | None = None
|
||||
self._discover_task: asyncio.Task | None = None
|
||||
|
||||
|
123
homeassistant/components/dhcp/websocket_api.py
Normal file
123
homeassistant/components/dhcp/websocket_api.py
Normal file
@ -0,0 +1,123 @@
|
||||
"""The dhcp integration websocket apis."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, TypedDict
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import websocket_api
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers.json import json_bytes
|
||||
|
||||
from . import (
|
||||
HOSTNAME,
|
||||
IP_ADDRESS,
|
||||
_DhcpServiceInfo,
|
||||
async_get_address_data_internal,
|
||||
async_register_dhcp_callback_internal,
|
||||
)
|
||||
|
||||
|
||||
class DHCPDiscovery(TypedDict):
|
||||
"""Typed dict for DHCP discovery."""
|
||||
|
||||
mac_address: str
|
||||
hostname: str
|
||||
ip_address: str
|
||||
|
||||
|
||||
@callback
|
||||
def async_setup(hass: HomeAssistant) -> None:
|
||||
"""Set up the DHCP websocket API."""
|
||||
websocket_api.async_register_command(hass, ws_subscribe_discovery)
|
||||
|
||||
|
||||
def serialize_service_info(service_info: _DhcpServiceInfo) -> DHCPDiscovery:
|
||||
"""Serialize a _DhcpServiceInfo object."""
|
||||
serialized: DHCPDiscovery = {
|
||||
"mac_address": service_info.macaddress,
|
||||
"hostname": service_info.hostname,
|
||||
"ip_address": service_info.ip,
|
||||
}
|
||||
return serialized
|
||||
|
||||
|
||||
class _DiscoverySubscription:
|
||||
"""Class to hold and manage the subscription data."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
connection: websocket_api.ActiveConnection,
|
||||
ws_msg_id: int,
|
||||
) -> None:
|
||||
"""Initialize the subscription data."""
|
||||
self.hass = hass
|
||||
self.ws_msg_id = ws_msg_id
|
||||
self.connection = connection
|
||||
|
||||
@callback
|
||||
def async_start(self) -> None:
|
||||
"""Start the subscription."""
|
||||
connection = self.connection
|
||||
connection.subscriptions[self.ws_msg_id] = (
|
||||
async_register_dhcp_callback_internal(
|
||||
self.hass,
|
||||
self._async_on_discovery,
|
||||
)
|
||||
)
|
||||
connection.send_message(
|
||||
json_bytes(websocket_api.result_message(self.ws_msg_id))
|
||||
)
|
||||
self._async_send_current_address_data()
|
||||
|
||||
def _async_event_message(self, message: dict[str, Any]) -> None:
|
||||
self.connection.send_message(
|
||||
json_bytes(websocket_api.event_message(self.ws_msg_id, message))
|
||||
)
|
||||
|
||||
def _async_send_current_address_data(self) -> None:
|
||||
"""Send the current address data."""
|
||||
address_data = async_get_address_data_internal(self.hass)
|
||||
self._async_event_message(
|
||||
{
|
||||
"add": [
|
||||
{
|
||||
"address": mac_address,
|
||||
"hostname": data[HOSTNAME],
|
||||
"ip_address": data[IP_ADDRESS],
|
||||
}
|
||||
for mac_address, data in address_data.items()
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
def _async_added(self, service_infos: list[_DhcpServiceInfo]) -> None:
|
||||
self._async_event_message(
|
||||
{
|
||||
"add": [
|
||||
serialize_service_info(service_info)
|
||||
for service_info in service_infos
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
@callback
|
||||
def _async_on_discovery(self, service_info: _DhcpServiceInfo) -> None:
|
||||
"""Handle the callback."""
|
||||
self._async_event_message({"add": [serialize_service_info(service_info)]})
|
||||
|
||||
|
||||
@websocket_api.require_admin
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "dhcp/subscribe_discovery",
|
||||
}
|
||||
)
|
||||
@websocket_api.async_response
|
||||
async def ws_subscribe_discovery(
|
||||
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict[str, Any]
|
||||
) -> None:
|
||||
"""Handle subscribe discovery websocket command."""
|
||||
_DiscoverySubscription(hass, connection, msg["id"]).async_start()
|
Loading…
x
Reference in New Issue
Block a user